Compare commits

..

7 Commits

17 changed files with 1493 additions and 1526 deletions

View File

@ -162,7 +162,7 @@ jobs:
strategy:
matrix:
variant: [musa, sycl, vulkan]
variant: [musa, sycl, vulkan, cuda]
env:
REGISTRY: ghcr.io

View File

@ -36,7 +36,6 @@ option(SD_VULKAN "sd: vulkan backend" OFF)
option(SD_OPENCL "sd: opencl backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF)
option(SD_MUSA "sd: musa backend" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
option(SD_BUILD_SHARED_GGML_LIB "sd: build ggml as a separate shared lib" OFF)
option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF)
@ -70,18 +69,12 @@ if (SD_HIPBLAS)
message("-- Use HIPBLAS as backend stable-diffusion")
set(GGML_HIP ON)
add_definitions(-DSD_USE_CUDA)
if(SD_FAST_SOFTMAX)
set(GGML_CUDA_FAST_SOFTMAX ON)
endif()
endif ()
if(SD_MUSA)
message("-- Use MUSA as backend stable-diffusion")
set(GGML_MUSA ON)
add_definitions(-DSD_USE_CUDA)
if(SD_FAST_SOFTMAX)
set(GGML_CUDA_FAST_SOFTMAX ON)
endif()
endif()
set(SD_LIB stable-diffusion)

25
Dockerfile.cuda Normal file
View File

@ -0,0 +1,25 @@
ARG CUDA_VERSION=12.6.3
ARG UBUNTU_VERSION=24.04
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu${UBUNTU_VERSION} AS build
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git ccache cmake
WORKDIR /sd.cpp
COPY . .
ARG CUDACXX=/usr/local/cuda/bin/nvcc
RUN cmake . -B ./build -DSD_CUDA=ON
RUN cmake --build ./build --config Release -j$(nproc)
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-runtime-ubuntu${UBUNTU_VERSION} AS runtime
RUN apt-get update && \
apt-get install --yes --no-install-recommends libgomp1 && \
apt-get clean
COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
ENTRYPOINT [ "/sd-cli" ]

View File

@ -5,6 +5,7 @@
- Download Anima
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/Bedovyy/Anima-GGUF/tree/main
- gguf Anima2: https://huggingface.co/JusteLeo/Anima2-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae
- Download Qwen3-0.6B-Base
@ -17,4 +18,4 @@
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\anima-preview.safetensors --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_06b_base.safetensors -p "a lovely cat holding a sign says 'anima.cpp'" --cfg-scale 6.0 --sampling-method euler -v --offload-to-cpu --diffusion-fa
```
<img alt="anima image example" src="../assets/anima/example.png" />
<img alt="anima image example" src="../assets/anima/example.png" />

View File

@ -80,7 +80,7 @@ Uses Taylor series approximation to predict block outputs:
Combines DBCache and TaylorSeer:
```bash
--cache-mode cache-dit --cache-preset fast
--cache-mode cache-dit
```
#### Parameters
@ -92,14 +92,6 @@ Combines DBCache and TaylorSeer:
| `threshold` | L1 residual difference threshold | 0.08 |
| `warmup` | Steps before caching starts | 8 |
#### Presets
Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
```bash
--cache-mode cache-dit --cache-preset fast
```
#### SCM Options
Steps Computation Mask controls which steps can be cached:

View File

@ -139,12 +139,11 @@ Generation Options:
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
'spectrum' (UNET Chebyshev+Taylor forecasting)
'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'
```

View File

@ -1047,7 +1047,6 @@ struct SDGenerationParams {
std::string cache_mode;
std::string cache_option;
std::string cache_preset;
std::string scm_mask;
bool scm_policy_dynamic = true;
sd_cache_params_t cache_params{};
@ -1461,21 +1460,6 @@ struct SDGenerationParams {
return 1;
};
auto on_cache_preset_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
cache_preset = argv_to_utf8(index, argv);
if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" &&
cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" &&
cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" &&
cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") {
fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str());
return -1;
}
return 1;
};
options.manual_options = {
{"-s",
"--seed",
@ -1513,16 +1497,12 @@ struct SDGenerationParams {
on_ref_image_arg},
{"",
"--cache-mode",
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)",
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)",
on_cache_mode_arg},
{"",
"--cache-option",
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
on_cache_option_arg},
{"",
"--cache-preset",
"cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'",
on_cache_preset_arg},
{"",
"--scm-mask",
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
@ -1575,7 +1555,6 @@ struct SDGenerationParams {
load_if_exists("negative_prompt", negative_prompt);
load_if_exists("cache_mode", cache_mode);
load_if_exists("cache_option", cache_option);
load_if_exists("cache_preset", cache_preset);
load_if_exists("scm_mask", scm_mask);
load_if_exists("clip_skip", clip_skip);
@ -1810,48 +1789,17 @@ struct SDGenerationParams {
if (!cache_mode.empty()) {
if (cache_mode == "easycache") {
cache_params.mode = SD_CACHE_EASYCACHE;
cache_params.reuse_threshold = 0.2f;
cache_params.start_percent = 0.15f;
cache_params.end_percent = 0.95f;
cache_params.error_decay_rate = 1.0f;
cache_params.use_relative_threshold = true;
cache_params.reset_error_on_compute = true;
cache_params.mode = SD_CACHE_EASYCACHE;
} else if (cache_mode == "ucache") {
cache_params.mode = SD_CACHE_UCACHE;
cache_params.reuse_threshold = 1.0f;
cache_params.start_percent = 0.15f;
cache_params.end_percent = 0.95f;
cache_params.error_decay_rate = 1.0f;
cache_params.use_relative_threshold = true;
cache_params.reset_error_on_compute = true;
cache_params.mode = SD_CACHE_UCACHE;
} else if (cache_mode == "dbcache") {
cache_params.mode = SD_CACHE_DBCACHE;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
cache_params.mode = SD_CACHE_DBCACHE;
} else if (cache_mode == "taylorseer") {
cache_params.mode = SD_CACHE_TAYLORSEER;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
cache_params.mode = SD_CACHE_TAYLORSEER;
} else if (cache_mode == "cache-dit") {
cache_params.mode = SD_CACHE_CACHE_DIT;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
cache_params.mode = SD_CACHE_CACHE_DIT;
} else if (cache_mode == "spectrum") {
cache_params.mode = SD_CACHE_SPECTRUM;
cache_params.spectrum_w = 0.40f;
cache_params.spectrum_m = 3;
cache_params.spectrum_lam = 1.0f;
cache_params.spectrum_window_size = 2;
cache_params.spectrum_flex_window = 0.50f;
cache_params.spectrum_warmup_steps = 4;
cache_params.spectrum_stop_percent = 0.9f;
cache_params.mode = SD_CACHE_SPECTRUM;
}
if (!cache_option.empty()) {

View File

@ -129,11 +129,10 @@ Default Generation Options:
--skip-layers layers to skip for SLG steps (default: [7,8,9])
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'
```

930
src/auto_encoder_kl.hpp Normal file
View File

@ -0,0 +1,930 @@
#ifndef __AUTO_ENCODER_KL_HPP__
#define __AUTO_ENCODER_KL_HPP__
#include "vae.hpp"
/*================================================== AutoEncoderKL ===================================================*/
#define VAE_GRAPH_SIZE 20480
class ResnetBlock : public UnaryBlock {
protected:
int64_t in_channels;
int64_t out_channels;
public:
ResnetBlock(int64_t in_channels,
int64_t out_channels)
: in_channels(in_channels),
out_channels(out_channels) {
// temb_channels is always 0
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
if (out_channels != in_channels) {
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
// t_emb is always None
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
auto h = x;
h = norm1->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
h = conv1->forward(ctx, h);
// return h;
h = norm2->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
// dropout, skip for inference
h = conv2->forward(ctx, h);
// skip connection
if (out_channels != in_channels) {
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
}
h = ggml_add(ctx->ggml_ctx, h, x);
return h; // [N, out_channels, h, w]
}
};
class AttnBlock : public UnaryBlock {
protected:
int64_t in_channels;
bool use_linear;
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
if (iter != tensor_storage_map.end()) {
if (iter->second.n_dims == 4 && use_linear) {
use_linear = false;
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
} else if (iter->second.n_dims == 2 && !use_linear) {
use_linear = true;
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
}
}
}
public:
AttnBlock(int64_t in_channels, bool use_linear)
: in_channels(in_channels), use_linear(use_linear) {
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
if (use_linear) {
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
} else {
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
auto h_ = norm->forward(ctx, x);
const int64_t n = h_->ne[3];
const int64_t c = h_->ne[2];
const int64_t h = h_->ne[1];
const int64_t w = h_->ne[0];
ggml_tensor* q;
ggml_tensor* k;
ggml_tensor* v;
if (use_linear) {
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
} else {
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
} else {
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
}
h_ = ggml_add(ctx->ggml_ctx, h_, x);
return h_;
}
};
class AE3DConv : public Conv2d {
public:
AE3DConv(int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
int video_kernel_size = 3,
std::pair<int, int> stride = {1, 1},
std::pair<int, int> padding = {0, 0},
std::pair<int, int> dilation = {1, 1},
bool bias = true)
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
int kernel_padding = video_kernel_size / 2;
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
out_channels,
{video_kernel_size, 1, 1},
{1, 1, 1},
{kernel_padding, 0, 0}));
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) override {
// timesteps always None
// skip_video always False
// x: [N, IC, IH, IW]
// result: [N, OC, OH, OW]
auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
x = Conv2d::forward(ctx, x);
// timesteps = x.shape[0]
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
// x = conv3d(x)
// return rearrange(x, "b c t h w -> (b t) c h w")
int64_t T = x->ne[3];
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x; // [B*T, OC, OH, OW]
}
};
class VideoResnetBlock : public ResnetBlock {
protected:
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
}
float get_alpha() {
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha);
}
public:
VideoResnetBlock(int64_t in_channels,
int64_t out_channels,
int video_kernel_size = 3)
: ResnetBlock(in_channels, out_channels) {
// merge_strategy is always learned
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
// t_emb is always None
// skip_video is always False
// timesteps is always None
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w]
// return x;
int64_t T = x->ne[3];
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
auto x_mix = x;
x = time_stack->forward(ctx, x); // b t c (h w)
float alpha = get_alpha();
x = ggml_add(ctx->ggml_ctx,
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x;
}
};
// ldm.modules.diffusionmodules.model.Encoder
class Encoder : public GGMLBlock {
protected:
int ch = 128;
std::vector<int> ch_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
int in_channels = 3;
int z_channels = 4;
bool double_z = true;
public:
Encoder(int ch,
std::vector<int> ch_mult,
int num_res_blocks,
int in_channels,
int z_channels,
bool double_z = true,
bool use_linear_projection = false)
: ch(ch),
ch_mult(ch_mult),
num_res_blocks(num_res_blocks),
in_channels(in_channels),
z_channels(z_channels),
double_z(double_z) {
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
size_t num_resolutions = ch_mult.size();
int block_in = 1;
for (int i = 0; i < num_resolutions; i++) {
if (i == 0) {
block_in = ch;
} else {
block_in = ch * ch_mult[i - 1];
}
int block_out = ch * ch_mult[i];
for (int j = 0; j < num_res_blocks; j++) {
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
block_in = block_out;
}
if (i != num_resolutions - 1) {
std::string name = "down." + std::to_string(i) + ".downsample";
blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
}
}
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
}
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
// downsampling
size_t num_resolutions = ch_mult.size();
for (int i = 0; i < num_resolutions; i++) {
for (int j = 0; j < num_res_blocks; j++) {
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
h = down_block->forward(ctx, h);
}
if (i != num_resolutions - 1) {
std::string name = "down." + std::to_string(i) + ".downsample";
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
h = down_sample->forward(ctx, h);
}
}
// middle
h = mid_block_1->forward(ctx, h);
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// end
h = norm_out->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
return h;
}
};
// ldm.modules.diffusionmodules.model.Decoder
class Decoder : public GGMLBlock {
protected:
int ch = 128;
int out_ch = 3;
std::vector<int> ch_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
int z_channels = 4;
bool video_decoder = false;
int video_kernel_size = 3;
virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
std::pair<int, int> stride = {1, 1},
std::pair<int, int> padding = {0, 0}) {
if (video_decoder) {
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
} else {
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
}
}
virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
int64_t out_channels) {
if (video_decoder) {
return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
} else {
return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
}
}
public:
Decoder(int ch,
int out_ch,
std::vector<int> ch_mult,
int num_res_blocks,
int z_channels,
bool use_linear_projection = false,
bool video_decoder = false,
int video_kernel_size = 3)
: ch(ch),
out_ch(out_ch),
ch_mult(ch_mult),
num_res_blocks(num_res_blocks),
z_channels(z_channels),
video_decoder(video_decoder),
video_kernel_size(video_kernel_size) {
int num_resolutions = static_cast<int>(ch_mult.size());
int block_in = ch * ch_mult[num_resolutions - 1];
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
for (int i = num_resolutions - 1; i >= 0; i--) {
int mult = ch_mult[i];
int block_out = ch * mult;
for (int j = 0; j < num_res_blocks + 1; j++) {
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
blocks[name] = get_resnet_block(block_in, block_out);
block_in = block_out;
}
if (i != 0) {
std::string name = "up." + std::to_string(i) + ".upsample";
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
}
}
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
}
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
// alpha is always 0
// merge_strategy is always learned
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
// AttnVideoBlock will not be used
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
// conv_in
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
// middle
h = mid_block_1->forward(ctx, h);
// return h;
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// upsampling
int num_resolutions = static_cast<int>(ch_mult.size());
for (int i = num_resolutions - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
h = up_block->forward(ctx, h);
}
if (i != 0) {
std::string name = "up." + std::to_string(i) + ".upsample";
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
h = up_sample->forward(ctx, h);
}
}
h = norm_out->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
return h;
}
};
// ldm.models.autoencoder.AutoencoderKL
class AutoEncoderKLModel : public GGMLBlock {
protected:
SDVersion version;
bool decode_only = true;
bool use_video_decoder = false;
bool use_quant = true;
int embed_dim = 4;
struct {
int z_channels = 4;
int resolution = 256;
int in_channels = 3;
int out_ch = 3;
int ch = 128;
std::vector<int> ch_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
bool double_z = true;
} dd_config;
public:
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false)
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) {
if (sd_version_is_flux2(version)) {
dd_config.z_channels = 32;
embed_dim = 32;
} else {
use_quant = false;
dd_config.z_channels = 16;
}
}
if (use_video_decoder) {
use_quant = false;
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
dd_config.out_ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
dd_config.z_channels,
use_linear_projection,
use_video_decoder));
if (use_quant) {
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
embed_dim,
{1, 1}));
}
if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
dd_config.in_channels,
dd_config.z_channels,
dd_config.double_z,
use_linear_projection));
if (use_quant) {
int factor = dd_config.double_z ? 2 : 1;
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
dd_config.z_channels * factor,
{1, 1}));
}
}
}
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
if (sd_version_is_flux2(version)) {
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
int64_t p = 2;
int64_t N = z->ne[3];
int64_t C = z->ne[2] / p / p;
int64_t h = z->ne[1];
int64_t w = z->ne[0];
int64_t H = h * p;
int64_t W = w * p;
z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w]
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p]
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p]
}
if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
}
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
ggml_set_name(z, "bench-start");
auto h = decoder->forward(ctx, z);
ggml_set_name(h, "bench-end");
return h;
}
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
if (use_quant) {
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
}
if (sd_version_is_flux2(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
int64_t p = 2;
int64_t N = z->ne[3];
int64_t C = z->ne[2];
int64_t H = z->ne[1];
int64_t W = z->ne[0];
int64_t h = H / p;
int64_t w = W / p;
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p]
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p]
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w]
z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w]
}
return z;
}
int get_encoder_output_channels() {
int factor = dd_config.double_z ? 2 : 1;
return dd_config.z_channels * factor;
}
};
struct AutoEncoderKL : public VAE {
float scale_factor = 1.f;
float shift_factor = 0.f;
bool decode_only = true;
AutoEncoderKLModel ae;
AutoEncoderKL(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool decode_only = false,
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), VAE(version, backend, offload_params_to_cpu) {
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
scale_factor = 0.18215f;
shift_factor = 0.f;
} else if (sd_version_is_sdxl(version)) {
scale_factor = 0.13025f;
shift_factor = 0.f;
} else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f;
shift_factor = 0.0609f;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
scale_factor = 0.3611f;
shift_factor = 0.1159f;
} else if (sd_version_is_flux2(version)) {
scale_factor = 1.0f;
shift_factor = 0.f;
}
bool use_linear_projection = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (ends_with(name, "attn_1.proj_out.weight")) {
if (tensor_storage.n_dims == 2) {
use_linear_projection = true;
}
break;
}
}
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
ae.init(params_ctx, tensor_storage_map, prefix);
}
void set_conv2d_scale(float scale) override {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->set_scale(scale);
}
}
}
std::string get_desc() override {
return "vae";
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
ae.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);
return gf;
}
bool _compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) override {
GGML_ASSERT(!decode_only || decode_graph);
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};
// ggml_set_f32(z, 0.5f);
// print_ggml_tensor(z);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments, std::shared_ptr<RNG> rng) {
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
ggml_tensor* latents = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latents);
ggml_ext_im_set_randn_f32(noise, rng);
{
float mean = 0;
float logvar = 0;
float value = 0;
float std_ = 0;
for (int i = 0; i < latents->ne[3]; i++) {
for (int j = 0; j < latents->ne[2]; j++) {
for (int k = 0; k < latents->ne[1]; k++) {
for (int l = 0; l < latents->ne[0]; l++) {
mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latents->ne[2], i);
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
std_ = std::exp(0.5f * logvar);
value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
ggml_ext_tensor_set_f32(latents, value, l, k, j, i);
}
}
}
}
}
return latents;
}
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
if (sd_version_is_flux2(version)) {
return vae_output;
} else if (version == VERSION_SD1_PIX2PIX) {
return ggml_view_3d(work_ctx,
vae_output,
vae_output->ne[0],
vae_output->ne[1],
vae_output->ne[2] / 2,
vae_output->nb[1],
vae_output->nb[2],
0);
} else {
return gaussian_latent_sample(work_ctx, vae_output, rng);
}
}
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
// flux2
if (sd_version_is_flux2(version)) {
GGML_ASSERT(latents->ne[channel_dim] == 128);
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
latents_std_vec = {
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
} else {
GGML_ABORT("unknown version %d", version);
}
}
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
if (sd_version_is_flux2(version)) {
int channel_dim = 2;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latents->ne[3]; i++) {
if (channel_dim == 3) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latents->ne[2]; j++) {
if (channel_dim == 2) {
mean = latents_mean_vec[j];
std_ = latents_std_vec[j];
}
for (int k = 0; k < latents->ne[1]; k++) {
for (int l = 0; l < latents->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
value = value * std_ / scale_factor + mean;
ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
}
}
}
}
} else {
ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3);
value = (value / scale_factor) + shift_factor;
ggml_ext_tensor_set_f32(vae_latents, value, i0, i1, i2, i3);
});
}
return vae_latents;
}
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
if (sd_version_is_flux2(version)) {
int channel_dim = 2;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latents->ne[3]; i++) {
if (channel_dim == 3) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latents->ne[2]; j++) {
if (channel_dim == 2) {
mean = latents_mean_vec[j];
std_ = latents_std_vec[j];
}
for (int k = 0; k < latents->ne[1]; k++) {
for (int l = 0; l < latents->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
value = (value - mean) * scale_factor / std_;
ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
}
}
}
}
} else {
ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3);
value = (value - shift_factor) * scale_factor;
ggml_ext_tensor_set_f32(diffusion_latents, value, i0, i1, i2, i3);
});
}
return diffusion_latents;
}
int get_encoder_output_channels(int input_channels) {
return ae.get_encoder_output_channels();
}
void test() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
params.mem_buffer = nullptr;
params.no_alloc = false;
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != nullptr);
{
// CPU, x{1, 3, 64, 64}: Pass
// CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
// CPU, x{2, 3, 64, 64}: Wrong result
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
ggml_set_f32(x, 0.5f);
print_ggml_tensor(x);
struct ggml_tensor* out = nullptr;
int64_t t0 = ggml_time_ms();
_compute(8, x, false, &out, work_ctx);
int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("encode test done in %lldms", t1 - t0);
}
if (false) {
// CPU, z{1, 4, 8, 8}: Pass
// CUDA, z{1, 4, 8, 8}: Pass
// CPU, z{3, 4, 8, 8}: Wrong result
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(z, 0.5f);
print_ggml_tensor(z);
struct ggml_tensor* out = nullptr;
int64_t t0 = ggml_time_ms();
_compute(8, z, true, &out, work_ctx);
int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("decode test done in %lldms", t1 - t0);
}
};
};
#endif // __AUTO_ENCODER_KL_HPP__

View File

@ -603,87 +603,6 @@ inline std::vector<int> generate_scm_mask(
return mask;
}
inline std::vector<int> get_scm_preset(const std::string& preset, int total_steps) {
struct Preset {
std::vector<int> compute_bins;
std::vector<int> cache_bins;
};
Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}};
Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}};
Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}};
Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}};
Preset* p = nullptr;
if (preset == "slow" || preset == "s" || preset == "S")
p = &slow;
else if (preset == "medium" || preset == "m" || preset == "M")
p = &medium;
else if (preset == "fast" || preset == "f" || preset == "F")
p = &fast;
else if (preset == "ultra" || preset == "u" || preset == "U")
p = &ultra;
else
return {};
if (total_steps != 28 && total_steps > 0) {
float scale = static_cast<float>(total_steps) / 28.0f;
std::vector<int> scaled_compute, scaled_cache;
for (int v : p->compute_bins) {
scaled_compute.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
}
for (int v : p->cache_bins) {
scaled_cache.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
}
return generate_scm_mask(scaled_compute, scaled_cache, total_steps);
}
return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps);
}
inline float get_preset_threshold(const std::string& preset) {
if (preset == "slow" || preset == "s" || preset == "S")
return 0.20f;
if (preset == "medium" || preset == "m" || preset == "M")
return 0.25f;
if (preset == "fast" || preset == "f" || preset == "F")
return 0.30f;
if (preset == "ultra" || preset == "u" || preset == "U")
return 0.34f;
return 0.08f;
}
inline int get_preset_warmup(const std::string& preset) {
if (preset == "slow" || preset == "s" || preset == "S")
return 8;
if (preset == "medium" || preset == "m" || preset == "M")
return 6;
if (preset == "fast" || preset == "f" || preset == "F")
return 6;
if (preset == "ultra" || preset == "u" || preset == "U")
return 4;
return 8;
}
inline int get_preset_Fn(const std::string& preset) {
if (preset == "slow" || preset == "s" || preset == "S")
return 8;
if (preset == "medium" || preset == "m" || preset == "M")
return 8;
if (preset == "fast" || preset == "f" || preset == "F")
return 6;
if (preset == "ultra" || preset == "u" || preset == "U")
return 4;
return 8;
}
inline int get_preset_Bn(const std::string& preset) {
(void)preset;
return 0;
}
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
if (opts.empty())
return;

View File

@ -377,6 +377,12 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten
ggml_free(ctx);
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_dup_and_cpy_tensor(ggml_context* ctx, ggml_tensor* src) {
ggml_tensor* dup = ggml_dup_tensor(ctx, src);
copy_ggml_tensor(dup, src);
return dup;
}
__STATIC_INLINE__ float sigmoid(float x) {
return 1 / (1.0f + expf(-x));
}
@ -637,7 +643,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_tensor_concat(struct ggml_context
}
// convert values from [0, 1] to [-1, 1]
__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
__STATIC_INLINE__ void scale_to_minus1_1(struct ggml_tensor* src) {
int64_t nelements = ggml_nelements(src);
float* data = (float*)src->data;
for (int i = 0; i < nelements; i++) {
@ -647,7 +653,7 @@ __STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
}
// convert values from [-1, 1] to [0, 1]
__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
__STATIC_INLINE__ void scale_to_0_1(struct ggml_tensor* src) {
int64_t nelements = ggml_nelements(src);
float* data = (float*)src->data;
for (int i = 0; i < nelements; i++) {
@ -834,7 +840,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
const float tile_overlap_factor,
const bool circular_x,
const bool circular_y,
on_tile_process on_processing) {
on_tile_process on_processing,
bool slient = false) {
output = ggml_set_f32(output, 0);
int input_width = (int)input->ne[0];
@ -864,8 +871,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
float tile_overlap_factor_y;
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
if (!slient) {
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
}
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
@ -896,7 +905,9 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
params.mem_buffer = nullptr;
params.no_alloc = false;
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
if (!slient) {
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
}
// draft context
struct ggml_context* tiles_ctx = ggml_init(params);
@ -909,8 +920,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
int num_tiles = num_tiles_x * num_tiles_y;
LOG_DEBUG("processing %i tiles", num_tiles);
pretty_progress(0, num_tiles, 0.0f);
if (!slient) {
LOG_DEBUG("processing %i tiles", num_tiles);
pretty_progress(0, num_tiles, 0.0f);
}
int tile_count = 1;
bool last_y = false, last_x = false;
float last_time = 0.0f;
@ -960,8 +973,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
}
last_x = false;
}
if (tile_count < num_tiles) {
pretty_progress(num_tiles, num_tiles, last_time);
if (!slient) {
if (tile_count < num_tiles) {
pretty_progress(num_tiles, num_tiles, last_time);
}
}
ggml_free(tiles_ctx);
}

View File

@ -1104,10 +1104,12 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true;
}
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos ||
tensor_storage.name.find("unet.up_blocks.1.attentions.0.transformer_blocks.1") != std::string::npos) {
has_output_block_311 = true;
}
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos ||
tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) {
has_output_block_71 = true;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||

View File

@ -1120,7 +1120,11 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
for (const auto& prefix : first_stage_model_prefix_vec) {
if (starts_with(name, prefix)) {
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
name = prefix + name;
if (version == VERSION_SDXS) {
name = "tae." + name;
} else {
name = prefix + name;
}
break;
}
}

View File

@ -7,6 +7,7 @@
#include "stable-diffusion.h"
#include "util.h"
#include "auto_encoder_kl.hpp"
#include "cache_dit.hpp"
#include "conditioner.hpp"
#include "control.hpp"
@ -90,12 +91,17 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
}
}
void suppress_pp(int step, int steps, float time, void* data) {
(void)step;
(void)steps;
(void)time;
(void)data;
return;
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
float reuse_threshold = params.reuse_threshold;
if (reuse_threshold == INFINITY) {
if (params.mode == SD_CACHE_EASYCACHE) {
reuse_threshold = 0.2;
}
else if (params.mode == SD_CACHE_UCACHE) {
reuse_threshold = 1.0;
}
}
return std::max(0.0f, reuse_threshold);
}
/*=============================================== StableDiffusionGGML ================================================*/
@ -118,8 +124,6 @@ public:
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
std::shared_ptr<RNG> sampler_rng = nullptr;
int n_threads = -1;
float scale_factor = 0.18215f;
float shift_factor = 0.f;
float default_flow_shift = INFINITY;
std::shared_ptr<Conditioner> cond_stage_model;
@ -127,7 +131,7 @@ public:
std::shared_ptr<DiffusionModel> diffusion_model;
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
std::shared_ptr<VAE> first_stage_model;
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
std::shared_ptr<VAE> preview_vae;
std::shared_ptr<ControlNet> control_net;
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
std::shared_ptr<LoraModel> pmid_lora;
@ -138,7 +142,6 @@ public:
bool apply_lora_immediately = false;
std::string taesd_path;
bool use_tiny_autoencoder = false;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
bool offload_params_to_cpu = false;
bool use_pmid = false;
@ -239,10 +242,10 @@ public:
n_threads = sd_ctx_params->n_threads;
vae_decode_only = sd_ctx_params->vae_decode_only;
free_params_immediately = sd_ctx_params->free_params_immediately;
taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
use_tiny_autoencoder = taesd_path.size() > 0;
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
bool use_tae = false;
rng = get_rng(sd_ctx_params->rng_type);
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
@ -332,6 +335,14 @@ public:
}
}
if (strlen(SAFE_STR(sd_ctx_params->taesd_path)) > 0) {
LOG_INFO("loading tae from '%s'", sd_ctx_params->taesd_path);
if (!model_loader.init_from_file(sd_ctx_params->taesd_path, "tae.")) {
LOG_WARN("loading tae from '%s' failed", sd_ctx_params->taesd_path);
}
use_tae = true;
}
model_loader.convert_tensors_name();
version = model_loader.get_sd_version();
@ -400,22 +411,6 @@ public:
apply_lora_immediately = false;
}
if (sd_version_is_sdxl(version)) {
scale_factor = 0.13025f;
} else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f;
shift_factor = 0.0609f;
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
scale_factor = 0.3611f;
shift_factor = 0.1159f;
} else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_flux2(version)) {
scale_factor = 1.0f;
shift_factor = 0.f;
}
if (sd_version_is_control(version)) {
// Might need vae encode for control cond
vae_decode_only = false;
@ -424,6 +419,7 @@ public:
bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS) {
tae_preview_only = false;
use_tae = true;
}
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
@ -610,31 +606,46 @@ public:
vae_backend = backend;
}
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
version);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else if (version == VERSION_CHROMA_RADIANCE) {
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
offload_params_to_cpu);
auto create_tae = [&]() -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version)) {
return std::make_shared<TinyVideoAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder",
vae_decode_only,
version);
} else {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
auto model = std::make_shared<TinyImageAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
"decoder.layers",
vae_decode_only,
false,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true);
}
return model;
}
};
auto create_vae = [&]() -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version)) {
return std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
version);
} else {
auto model = std::make_shared<AutoEncoderKL>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
version);
if (sd_version_is_sdxl(version) &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
float vae_conv_2d_scale = 1.f / 32.f;
@ -642,35 +653,40 @@ public:
"No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
"using Conv2D scale %.3f",
vae_conv_2d_scale);
first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
model->set_conv2d_scale(vae_conv_2d_scale);
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
return model;
}
};
if (version == VERSION_CHROMA_RADIANCE) {
LOG_INFO("using FakeVAE");
first_stage_model = std::make_shared<FakeVAE>(version,
vae_backend,
offload_params_to_cpu);
} else if (use_tae && !tae_preview_only) {
LOG_INFO("using TAE for encoding / decoding");
first_stage_model = create_tae();
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "tae");
} else {
LOG_INFO("using VAE for encoding / decoding");
first_stage_model = create_vae();
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
if (use_tae && tae_preview_only) {
LOG_INFO("using TAE for preview");
preview_vae = create_tae();
preview_vae->alloc_params_buffer();
preview_vae->get_param_tensors(tensors, "tae");
}
}
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder",
vae_decode_only,
version);
} else {
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder.layers",
vae_decode_only,
version);
if (version == VERSION_SDXS) {
tae_first_stage->alloc_params_buffer();
tae_first_stage->get_param_tensors(tensors, "first_stage_model");
}
}
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
tae_first_stage->set_conv2d_direct_enabled(true);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true);
if (preview_vae) {
preview_vae->set_conv2d_direct_enabled(true);
}
}
@ -743,8 +759,8 @@ public:
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
if (preview_vae) {
preview_vae->set_flash_attention_enabled(true);
}
}
@ -782,7 +798,7 @@ public:
std::set<std::string> ignore_tensors;
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
if (use_tiny_autoencoder) {
if (use_tae && !tae_preview_only) {
ignore_tensors.insert("first_stage_model.");
}
if (use_pmid) {
@ -796,6 +812,7 @@ public:
ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.conv1");
ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("tae.encoder");
ignore_tensors.insert("text_encoders.llm.visual.");
}
if (version == VERSION_OVIS_IMAGE) {
@ -822,15 +839,9 @@ public:
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
}
size_t vae_params_mem_size = 0;
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
}
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false;
}
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
vae_params_mem_size = first_stage_model->get_params_buffer_size();
if (preview_vae) {
vae_params_mem_size += preview_vae->get_params_buffer_size();
}
size_t control_net_params_mem_size = 0;
if (control_net) {
@ -983,7 +994,6 @@ public:
}
ggml_free(ctx);
use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
return true;
}
@ -1422,8 +1432,7 @@ public:
ggml_ext_tensor_scale_inplace(noise, augmentation_level);
ggml_ext_tensor_add_inplace(init_img, noise);
}
ggml_tensor* moments = vae_encode(work_ctx, init_img);
c_concat = get_first_stage_encoding(work_ctx, moments);
c_concat = encode_first_stage(work_ctx, init_img);
}
}
@ -1475,14 +1484,6 @@ public:
}
}
void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
sd_progress_cb_t cb = sd_get_progress_callback();
void* cbd = sd_get_progress_callback_data();
sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr);
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
sd_set_progress_callback(cb, cbd);
}
void preview_image(ggml_context* work_ctx,
int step,
struct ggml_tensor* latents,
@ -1575,37 +1576,14 @@ public:
free(data);
free(images);
} else {
if (preview_mode == PREVIEW_VAE) {
process_latent_out(latents);
if (vae_tiling_params.enabled) {
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
};
silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling);
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
if (preview_vae) {
latents = preview_vae->diffusion_to_vae_latents(work_ctx, latents);
result = preview_vae->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true);
} else {
first_stage_model->compute(n_threads, latents, true, &result, work_ctx);
latents = first_stage_model->diffusion_to_vae_latents(work_ctx, latents);
result = first_stage_model->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true);
}
first_stage_model->free_compute_buffer();
process_vae_output_tensor(result);
process_latent_in(latents);
} else if (preview_mode == PREVIEW_TAE) {
if (tae_first_stage == nullptr) {
LOG_WARN("TAE not found for preview");
return;
}
if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return tae_first_stage->compute(n_threads, in, true, &out, nullptr);
};
silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling);
} else {
tae_first_stage->compute(n_threads, latents, true, &result, work_ctx);
}
tae_first_stage->free_compute_buffer();
} else {
return;
}
@ -1715,7 +1693,7 @@ public:
} else {
EasyCacheConfig easycache_config;
easycache_config.enabled = true;
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
easycache_config.start_percent = cache_params->start_percent;
easycache_config.end_percent = cache_params->end_percent;
easycache_state.init(easycache_config, denoiser.get());
@ -1736,7 +1714,7 @@ public:
} else {
UCacheConfig ucache_config;
ucache_config.enabled = true;
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
ucache_config.start_percent = cache_params->start_percent;
ucache_config.end_percent = cache_params->end_percent;
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
@ -1797,9 +1775,9 @@ public:
}
}
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
bool spectrum_supported = sd_version_is_unet(version);
bool spectrum_supported = sd_version_is_unet(version) || sd_version_is_dit(version);
if (!spectrum_supported) {
LOG_WARN("Spectrum requested but not supported for this model type (only UNET models)");
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
} else {
SpectrumConfig spectrum_config;
spectrum_config.w = cache_params->spectrum_w;
@ -1829,8 +1807,7 @@ public:
}
size_t steps = sigmas.size() - 1;
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
copy_ggml_tensor(x, init_latent);
struct ggml_tensor* x = ggml_ext_dup_and_cpy_tensor(work_ctx, init_latent);
if (noise) {
x = denoiser->noise_scaling(sigmas[0], noise, x);
@ -2351,15 +2328,7 @@ public:
}
int get_vae_scale_factor() {
int vae_scale_factor = 8;
if (version == VERSION_WAN2_2_TI2V) {
vae_scale_factor = 16;
} else if (sd_version_is_flux2(version)) {
vae_scale_factor = 16;
} else if (version == VERSION_CHROMA_RADIANCE) {
vae_scale_factor = 1;
}
return vae_scale_factor;
return first_stage_model->get_scale_factor();
}
int get_diffusion_model_down_factor() {
@ -2414,383 +2383,28 @@ public:
} else {
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
}
ggml_set_f32(init_latent, shift_factor);
ggml_set_f32(init_latent, 0.f);
return init_latent;
}
void get_latents_mean_std_vec(ggml_tensor* latent, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
GGML_ASSERT(latent->ne[channel_dim] == 16 || latent->ne[channel_dim] == 48 || latent->ne[channel_dim] == 128);
if (latent->ne[channel_dim] == 16) {
latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
} else if (latent->ne[channel_dim] == 48) {
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
latents_std_vec = {
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
} else if (latent->ne[channel_dim] == 128) {
// flux2
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
latents_std_vec = {
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
}
}
void process_latent_in(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latent->ne[3]; i++) {
if (channel_dim == 3) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latent->ne[2]; j++) {
if (channel_dim == 2) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
value = (value - mean) * scale_factor / std_;
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
}
}
}
}
} else if (version == VERSION_CHROMA_RADIANCE) {
// pass
} else {
ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
value = (value - shift_factor) * scale_factor;
ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
});
}
}
void process_latent_out(ggml_tensor* latent) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latent->ne[3]; i++) {
if (channel_dim == 3) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latent->ne[2]; j++) {
if (channel_dim == 2) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
value = value * std_ / scale_factor + mean;
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
}
}
}
}
} else if (version == VERSION_CHROMA_RADIANCE) {
// pass
} else {
ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
value = (value / scale_factor) + shift_factor;
ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
});
}
}
void get_tile_sizes(int& tile_size_x,
int& tile_size_y,
float& tile_overlap,
const sd_tiling_params_t& params,
int64_t latent_x,
int64_t latent_y,
float encoding_factor = 1.0f) {
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
const int default_tile_size = 32;
const int min_tile_dimension = 4;
int tile_size = default_tile_size;
// factor <= 1 means simple fraction of the latent dimension
// factor > 1 means number of tiles across that dimension
if (factor > 0.f) {
if (factor > 1.0)
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
tile_size = static_cast<int>(std::round(latent_size * factor));
} else if (requested_size >= min_tile_dimension) {
tile_size = requested_size;
}
tile_size = static_cast<int>(tile_size * encoding_factor);
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
};
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
}
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x) {
int64_t t0 = ggml_time_ms();
ggml_tensor* result = nullptr;
const int vae_scale_factor = get_vae_scale_factor();
int64_t W = x->ne[0] / vae_scale_factor;
int64_t H = x->ne[1] / vae_scale_factor;
int64_t C = get_latent_channel();
if (vae_tiling_params.enabled) {
// TODO wan2.2 vae support?
int64_t ne2;
int64_t ne3;
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
ne2 = 1;
ne3 = C * x->ne[3];
} else {
int64_t out_channels = C;
bool encode_outputs_mu = use_tiny_autoencoder ||
sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
version == VERSION_CHROMA_RADIANCE;
if (!encode_outputs_mu) {
out_channels *= 2;
}
ne2 = out_channels;
ne3 = x->ne[3];
}
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
}
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
}
if (!use_tiny_autoencoder) {
process_vae_input_tensor(x);
if (vae_tiling_params.enabled) {
float tile_overlap;
int tile_size_x, tile_size_y;
// multiply tile size for encode to keep the compute buffer size consistent
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f);
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return first_stage_model->compute(n_threads, in, false, &out, work_ctx);
};
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else {
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
}
first_stage_model->free_compute_buffer();
} else {
if (vae_tiling_params.enabled) {
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return tae_first_stage->compute(n_threads, in, false, &out, nullptr);
};
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
} else {
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
}
tae_first_stage->free_compute_buffer();
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
return result;
}
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments) {
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
ggml_ext_im_set_randn_f32(noise, rng);
{
float mean = 0;
float logvar = 0;
float value = 0;
float std_ = 0;
for (int i = 0; i < latent->ne[3]; i++) {
for (int j = 0; j < latent->ne[2]; j++) {
for (int k = 0; k < latent->ne[1]; k++) {
for (int l = 0; l < latent->ne[0]; l++) {
mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i);
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
std_ = std::exp(0.5f * logvar);
value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
}
}
}
}
}
return latent;
}
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
ggml_tensor* latent;
if (use_tiny_autoencoder ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
version == VERSION_CHROMA_RADIANCE) {
latent = vae_output;
} else if (version == VERSION_SD1_PIX2PIX) {
latent = ggml_view_3d(work_ctx,
vae_output,
vae_output->ne[0],
vae_output->ne[1],
vae_output->ne[2] / 2,
vae_output->nb[1],
vae_output->nb[2],
0);
} else {
latent = gaussian_latent_sample(work_ctx, vae_output);
}
if (!use_tiny_autoencoder && version != VERSION_SD1_PIX2PIX) {
process_latent_in(latent);
}
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1);
}
return latent;
ggml_tensor* encode_to_vae_latents(ggml_context* work_ctx, ggml_tensor* x) {
ggml_tensor* vae_output = first_stage_model->encode(n_threads, work_ctx, x, vae_tiling_params, circular_x, circular_y);
ggml_tensor* latents = first_stage_model->vae_output_to_latents(work_ctx, vae_output, rng);
return latents;
}
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
ggml_tensor* vae_output = vae_encode(work_ctx, x);
return get_first_stage_encoding(work_ctx, vae_output);
ggml_tensor* latents = encode_to_vae_latents(work_ctx, x);
if (version != VERSION_SD1_PIX2PIX) {
latents = first_stage_model->vae_to_diffuison_latents(work_ctx, latents);
}
return latents;
}
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
const int vae_scale_factor = get_vae_scale_factor();
int64_t W = x->ne[0] * vae_scale_factor;
int64_t H = x->ne[1] * vae_scale_factor;
int64_t C = 3;
ggml_tensor* result = nullptr;
if (decode_video) {
int64_t T = x->ne[2];
if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1;
}
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
W,
H,
T,
3);
} else {
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
W,
H,
C,
x->ne[3]);
}
int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) {
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
}
process_latent_out(x);
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
if (vae_tiling_params.enabled) {
float tile_overlap;
int tile_size_x, tile_size_y;
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
};
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else {
if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) {
LOG_ERROR("Failed to decode latetnts");
first_stage_model->free_compute_buffer();
return nullptr;
}
}
first_stage_model->free_compute_buffer();
process_vae_output_tensor(result);
} else {
if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return tae_first_stage->compute(n_threads, in, true, &out);
};
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
} else {
if (!tae_first_stage->compute(n_threads, x, true, &result)) {
LOG_ERROR("Failed to decode latetnts");
tae_first_stage->free_compute_buffer();
return nullptr;
}
}
tae_first_stage->free_compute_buffer();
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
return result;
x = first_stage_model->diffusion_to_vae_latents(work_ctx, x);
x = first_stage_model->decode(n_threads, work_ctx, x, vae_tiling_params, decode_video, circular_x, circular_y);
return x;
}
void set_flow_shift(float flow_shift = INFINITY) {
@ -2983,7 +2597,7 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
void sd_cache_params_init(sd_cache_params_t* cache_params) {
*cache_params = {};
cache_params->mode = SD_CACHE_DISABLED;
cache_params->reuse_threshold = 1.0f;
cache_params->reuse_threshold = INFINITY;
cache_params->start_percent = 0.15f;
cache_params->end_percent = 0.95f;
cache_params->error_decay_rate = 1.0f;
@ -3229,7 +2843,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
snprintf(buf + strlen(buf), 4096 - strlen(buf),
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
cache_mode_str,
sd_img_gen_params->cache.reuse_threshold,
get_cache_reuse_threshold(sd_img_gen_params->cache),
sd_img_gen_params->cache.start_percent,
sd_img_gen_params->cache.end_percent);
free(sample_params_str);
@ -3560,7 +3174,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->first_stage_model->free_params_buffer();
}
@ -3609,15 +3223,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
if (sd_ctx->sd->first_stage_model) {
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
if (sd_ctx->sd->tae_first_stage) {
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
if (sd_ctx->sd->preview_vae) {
sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
} else {
int tile_size_x, tile_size_y;
float _overlap;
int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor();
int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor();
sd_ctx->sd->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
sd_ctx->sd->first_stage_model->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
// force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size
// otherwise it will cause artifacts at the edges of the tiles
@ -3627,8 +3241,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
if (sd_ctx->sd->first_stage_model) {
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
if (sd_ctx->sd->tae_first_stage) {
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
if (sd_ctx->sd->preview_vae) {
sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
}
// disable circular tiling if it's enabled for the VAE
@ -4105,14 +3719,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, init_img);
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
auto init_image_latent = sd_ctx->sd->encode_to_vae_latents(work_ctx, init_img); // [b*c, 1, h/16, w/16]
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
ggml_set_f32(denoise_mask, 1.f);
if (!sd_ctx->sd->use_tiny_autoencoder)
sd_ctx->sd->process_latent_out(init_latent);
init_latent = sd_ctx->sd->first_stage_model->diffusion_to_vae_latents(work_ctx, init_latent);
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
@ -4122,8 +3735,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
}
});
if (!sd_ctx->sd->use_tiny_autoencoder)
sd_ctx->sd->process_latent_in(init_latent);
init_latent = sd_ctx->sd->first_stage_model->vae_to_diffuison_latents(work_ctx, init_latent);
int64_t t2 = ggml_time_ms();
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
@ -4346,7 +3958,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
int64_t t5 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->first_stage_model->free_params_buffer();
}

View File

@ -442,11 +442,13 @@ protected:
bool decode_only;
SDVersion version;
public:
int z_channels = 16;
public:
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
: decode_only(decode_only), version(version) {
int z_channels = 16;
int patch = 1;
int patch = 1;
if (version == VERSION_WAN2_2_TI2V) {
z_channels = 48;
patch = 2;
@ -494,10 +496,12 @@ protected:
bool decode_only;
bool taef2 = false;
public:
int z_channels = 4;
public:
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
: decode_only(decode_only) {
int z_channels = 4;
bool use_midblock_gn = false;
taef2 = sd_version_is_flux2(version);
@ -533,20 +537,7 @@ public:
}
};
struct TinyAutoEncoder : public GGMLRunner {
TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
: GGMLRunner(backend, offload_params_to_cpu) {}
virtual bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) = 0;
virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
};
struct TinyImageAutoEncoder : public TinyAutoEncoder {
struct TinyImageAutoEncoder : public VAE {
TAESD taesd;
bool decode_only = false;
@ -558,7 +549,8 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
SDVersion version = VERSION_SD1)
: decode_only(decoder_only),
taesd(decoder_only, version),
TinyAutoEncoder(backend, offload_params_to_cpu) {
VAE(version, backend, offload_params_to_cpu) {
scale_input = false;
taesd.init(params_ctx, tensor_storage_map, prefix);
}
@ -566,37 +558,26 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
return "taesd";
}
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
alloc_params_buffer();
std::map<std::string, ggml_tensor*> taesd_tensors;
taesd.get_param_tensors(taesd_tensors);
std::set<std::string> ignore_tensors;
if (decode_only) {
ignore_tensors.insert("encoder.");
}
ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(file_path)) {
LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str());
return false;
}
bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors, n_threads);
if (!success) {
LOG_ERROR("load tae tensors from model loader failed");
return false;
}
LOG_INFO("taesd model loaded");
return success;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
taesd.get_param_tensors(tensors, prefix);
}
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
return vae_output;
}
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
}
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
}
int get_encoder_output_channels(int input_channels) {
return taesd.z_channels;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
@ -606,11 +587,11 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
return gf;
}
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) {
bool _compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};
@ -619,7 +600,7 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
}
};
struct TinyVideoAutoEncoder : public TinyAutoEncoder {
struct TinyVideoAutoEncoder : public VAE {
TAEHV taehv;
bool decode_only = false;
@ -631,7 +612,8 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
SDVersion version = VERSION_WAN2)
: decode_only(decoder_only),
taehv(decoder_only, version),
TinyAutoEncoder(backend, offload_params_to_cpu) {
VAE(version, backend, offload_params_to_cpu) {
scale_input = false;
taehv.init(params_ctx, tensor_storage_map, prefix);
}
@ -639,37 +621,26 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
return "taehv";
}
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
alloc_params_buffer();
std::map<std::string, ggml_tensor*> taehv_tensors;
taehv.get_param_tensors(taehv_tensors);
std::set<std::string> ignore_tensors;
if (decode_only) {
ignore_tensors.insert("encoder.");
}
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) {
LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
return false;
}
bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
if (!success) {
LOG_ERROR("load tae tensors from model loader failed");
return false;
}
LOG_INFO("taehv model loaded");
return success;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
taehv.get_param_tensors(tensors, prefix);
}
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
return vae_output;
}
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
}
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
}
int get_encoder_output_channels(int input_channels) {
return taehv.z_channels;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
@ -679,11 +650,11 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
return gf;
}
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) {
bool _compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};

View File

@ -3,635 +3,206 @@
#include "common_block.hpp"
/*================================================== AutoEncoderKL ===================================================*/
#define VAE_GRAPH_SIZE 20480
class ResnetBlock : public UnaryBlock {
protected:
int64_t in_channels;
int64_t out_channels;
public:
ResnetBlock(int64_t in_channels,
int64_t out_channels)
: in_channels(in_channels),
out_channels(out_channels) {
// temb_channels is always 0
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
if (out_channels != in_channels) {
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
// t_emb is always None
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
auto h = x;
h = norm1->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
h = conv1->forward(ctx, h);
// return h;
h = norm2->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
// dropout, skip for inference
h = conv2->forward(ctx, h);
// skip connection
if (out_channels != in_channels) {
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
}
h = ggml_add(ctx->ggml_ctx, h, x);
return h; // [N, out_channels, h, w]
}
};
class AttnBlock : public UnaryBlock {
protected:
int64_t in_channels;
bool use_linear;
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
if (iter != tensor_storage_map.end()) {
if (iter->second.n_dims == 4 && use_linear) {
use_linear = false;
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
} else if (iter->second.n_dims == 2 && !use_linear) {
use_linear = true;
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
}
}
}
public:
AttnBlock(int64_t in_channels, bool use_linear)
: in_channels(in_channels), use_linear(use_linear) {
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
if (use_linear) {
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
} else {
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
}
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
auto h_ = norm->forward(ctx, x);
const int64_t n = h_->ne[3];
const int64_t c = h_->ne[2];
const int64_t h = h_->ne[1];
const int64_t w = h_->ne[0];
ggml_tensor* q;
ggml_tensor* k;
ggml_tensor* v;
if (use_linear) {
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
} else {
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
} else {
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
}
h_ = ggml_add(ctx->ggml_ctx, h_, x);
return h_;
}
};
class AE3DConv : public Conv2d {
public:
AE3DConv(int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
int video_kernel_size = 3,
std::pair<int, int> stride = {1, 1},
std::pair<int, int> padding = {0, 0},
std::pair<int, int> dilation = {1, 1},
bool bias = true)
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
int kernel_padding = video_kernel_size / 2;
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
out_channels,
{video_kernel_size, 1, 1},
{1, 1, 1},
{kernel_padding, 0, 0}));
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
struct ggml_tensor* x) override {
// timesteps always None
// skip_video always False
// x: [N, IC, IH, IW]
// result: [N, OC, OH, OW]
auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
x = Conv2d::forward(ctx, x);
// timesteps = x.shape[0]
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
// x = conv3d(x)
// return rearrange(x, "b c t h w -> (b t) c h w")
int64_t T = x->ne[3];
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x; // [B*T, OC, OH, OW]
}
};
class VideoResnetBlock : public ResnetBlock {
protected:
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
}
float get_alpha() {
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha);
}
public:
VideoResnetBlock(int64_t in_channels,
int64_t out_channels,
int video_kernel_size = 3)
: ResnetBlock(in_channels, out_channels) {
// merge_strategy is always learned
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
}
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
// t_emb is always None
// skip_video is always False
// timesteps is always None
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w]
// return x;
int64_t T = x->ne[3];
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
auto x_mix = x;
x = time_stack->forward(ctx, x); // b t c (h w)
float alpha = get_alpha();
x = ggml_add(ctx->ggml_ctx,
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x;
}
};
// ldm.modules.diffusionmodules.model.Encoder
class Encoder : public GGMLBlock {
protected:
int ch = 128;
std::vector<int> ch_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
int in_channels = 3;
int z_channels = 4;
bool double_z = true;
public:
Encoder(int ch,
std::vector<int> ch_mult,
int num_res_blocks,
int in_channels,
int z_channels,
bool double_z = true,
bool use_linear_projection = false)
: ch(ch),
ch_mult(ch_mult),
num_res_blocks(num_res_blocks),
in_channels(in_channels),
z_channels(z_channels),
double_z(double_z) {
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
size_t num_resolutions = ch_mult.size();
int block_in = 1;
for (int i = 0; i < num_resolutions; i++) {
if (i == 0) {
block_in = ch;
} else {
block_in = ch * ch_mult[i - 1];
}
int block_out = ch * ch_mult[i];
for (int j = 0; j < num_res_blocks; j++) {
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
block_in = block_out;
}
if (i != num_resolutions - 1) {
std::string name = "down." + std::to_string(i) + ".downsample";
blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
}
}
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
}
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
// downsampling
size_t num_resolutions = ch_mult.size();
for (int i = 0; i < num_resolutions; i++) {
for (int j = 0; j < num_res_blocks; j++) {
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
h = down_block->forward(ctx, h);
}
if (i != num_resolutions - 1) {
std::string name = "down." + std::to_string(i) + ".downsample";
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
h = down_sample->forward(ctx, h);
}
}
// middle
h = mid_block_1->forward(ctx, h);
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// end
h = norm_out->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
return h;
}
};
// ldm.modules.diffusionmodules.model.Decoder
class Decoder : public GGMLBlock {
protected:
int ch = 128;
int out_ch = 3;
std::vector<int> ch_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
int z_channels = 4;
bool video_decoder = false;
int video_kernel_size = 3;
virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
std::pair<int, int> stride = {1, 1},
std::pair<int, int> padding = {0, 0}) {
if (video_decoder) {
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
} else {
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
}
}
virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
int64_t out_channels) {
if (video_decoder) {
return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
} else {
return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
}
}
public:
Decoder(int ch,
int out_ch,
std::vector<int> ch_mult,
int num_res_blocks,
int z_channels,
bool use_linear_projection = false,
bool video_decoder = false,
int video_kernel_size = 3)
: ch(ch),
out_ch(out_ch),
ch_mult(ch_mult),
num_res_blocks(num_res_blocks),
z_channels(z_channels),
video_decoder(video_decoder),
video_kernel_size(video_kernel_size) {
int num_resolutions = static_cast<int>(ch_mult.size());
int block_in = ch * ch_mult[num_resolutions - 1];
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
for (int i = num_resolutions - 1; i >= 0; i--) {
int mult = ch_mult[i];
int block_out = ch * mult;
for (int j = 0; j < num_res_blocks + 1; j++) {
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
blocks[name] = get_resnet_block(block_in, block_out);
block_in = block_out;
}
if (i != 0) {
std::string name = "up." + std::to_string(i) + ".upsample";
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
}
}
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
}
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
// alpha is always 0
// merge_strategy is always learned
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
// AttnVideoBlock will not be used
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
// conv_in
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
// middle
h = mid_block_1->forward(ctx, h);
// return h;
h = mid_attn_1->forward(ctx, h);
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
// upsampling
int num_resolutions = static_cast<int>(ch_mult.size());
for (int i = num_resolutions - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
h = up_block->forward(ctx, h);
}
if (i != 0) {
std::string name = "up." + std::to_string(i) + ".upsample";
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
h = up_sample->forward(ctx, h);
}
}
h = norm_out->forward(ctx, h);
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
return h;
}
};
// ldm.models.autoencoder.AutoencoderKL
class AutoencodingEngine : public GGMLBlock {
struct VAE : public GGMLRunner {
protected:
SDVersion version;
bool decode_only = true;
bool use_video_decoder = false;
bool use_quant = true;
int embed_dim = 4;
struct {
int z_channels = 4;
int resolution = 256;
int in_channels = 3;
int out_ch = 3;
int ch = 128;
std::vector<int> ch_mult = {1, 2, 4, 4};
int num_res_blocks = 2;
bool double_z = true;
} dd_config;
bool scale_input = true;
virtual bool _compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx) = 0;
public:
AutoencodingEngine(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false)
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) {
if (sd_version_is_flux2(version)) {
dd_config.z_channels = 32;
embed_dim = 32;
VAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
: version(version), GGMLRunner(backend, offload_params_to_cpu) {}
int get_scale_factor() {
int scale_factor = 8;
if (version == VERSION_WAN2_2_TI2V) {
scale_factor = 16;
} else if (sd_version_is_flux2(version)) {
scale_factor = 16;
} else if (version == VERSION_CHROMA_RADIANCE) {
scale_factor = 1;
}
return scale_factor;
}
virtual int get_encoder_output_channels(int input_channels) = 0;
void get_tile_sizes(int& tile_size_x,
int& tile_size_y,
float& tile_overlap,
const sd_tiling_params_t& params,
int64_t latent_x,
int64_t latent_y,
float encoding_factor = 1.0f) {
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
const int default_tile_size = 32;
const int min_tile_dimension = 4;
int tile_size = default_tile_size;
// factor <= 1 means simple fraction of the latent dimension
// factor > 1 means number of tiles across that dimension
if (factor > 0.f) {
if (factor > 1.0)
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
tile_size = static_cast<int>(std::round(latent_size * factor));
} else if (requested_size >= min_tile_dimension) {
tile_size = requested_size;
}
tile_size = static_cast<int>(tile_size * encoding_factor);
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
};
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
}
ggml_tensor* encode(int n_threads,
ggml_context* work_ctx,
ggml_tensor* x,
sd_tiling_params_t tiling_params,
bool circular_x = false,
bool circular_y = false) {
int64_t t0 = ggml_time_ms();
ggml_tensor* result = nullptr;
const int scale_factor = get_scale_factor();
int64_t W = x->ne[0] / scale_factor;
int64_t H = x->ne[1] / scale_factor;
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
int64_t C = get_encoder_output_channels(static_cast<int>(x->ne[channel_dim]));
int64_t ne2;
int64_t ne3;
if (sd_version_is_wan(version)) {
int64_t T = x->ne[2];
ne2 = (T - 1) / 4 + 1;
ne3 = C;
} else {
ne2 = C;
ne3 = x->ne[3];
}
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
if (scale_input) {
scale_to_minus1_1(x);
}
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
}
if (tiling_params.enabled) {
float tile_overlap;
int tile_size_x, tile_size_y;
// multiply tile size for encode to keep the compute buffer size consistent
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, W, H, 1.30539f);
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return _compute(n_threads, in, false, &out, work_ctx);
};
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
} else {
_compute(n_threads, x, false, &result, work_ctx);
}
free_compute_buffer();
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
return result;
}
ggml_tensor* decode(int n_threads,
ggml_context* work_ctx,
ggml_tensor* x,
sd_tiling_params_t tiling_params,
bool decode_video = false,
bool circular_x = false,
bool circular_y = false,
ggml_tensor* result = nullptr,
bool silent = false) {
const int scale_factor = get_scale_factor();
int64_t W = x->ne[0] * scale_factor;
int64_t H = x->ne[1] * scale_factor;
int64_t C = 3;
if (result == nullptr) {
if (decode_video) {
int64_t T = x->ne[2];
if (sd_version_is_wan(version)) {
T = ((T - 1) * 4) + 1;
}
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
W,
H,
T,
3);
} else {
use_quant = false;
dd_config.z_channels = 16;
result = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
W,
H,
C,
x->ne[3]);
}
}
if (use_video_decoder) {
use_quant = false;
int64_t t0 = ggml_time_ms();
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
dd_config.out_ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
dd_config.z_channels,
use_linear_projection,
use_video_decoder));
if (use_quant) {
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
embed_dim,
{1, 1}));
}
if (!decode_only) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
dd_config.in_channels,
dd_config.z_channels,
dd_config.double_z,
use_linear_projection));
if (use_quant) {
int factor = dd_config.double_z ? 2 : 1;
if (tiling_params.enabled) {
float tile_overlap;
int tile_size_x, tile_size_y;
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, x->ne[0], x->ne[1]);
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
dd_config.z_channels * factor,
{1, 1}));
if (!silent) {
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
}
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
return _compute(n_threads, in, true, &out, nullptr);
};
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling, silent);
} else {
if (!_compute(n_threads, x, true, &result, work_ctx)) {
LOG_ERROR("Failed to decode latetnts");
free_compute_buffer();
return nullptr;
}
}
free_compute_buffer();
if (scale_input) {
scale_to_0_1(result);
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
return result;
}
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
if (sd_version_is_flux2(version)) {
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
int64_t p = 2;
int64_t N = z->ne[3];
int64_t C = z->ne[2] / p / p;
int64_t h = z->ne[1];
int64_t w = z->ne[0];
int64_t H = h * p;
int64_t W = w * p;
z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w]
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p]
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p]
}
if (use_quant) {
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
}
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
ggml_set_name(z, "bench-start");
auto h = decoder->forward(ctx, z);
ggml_set_name(h, "bench-end");
return h;
}
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
// x: [N, in_channels, h, w]
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
if (use_quant) {
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
}
if (sd_version_is_flux2(version)) {
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
int64_t p = 2;
int64_t N = z->ne[3];
int64_t C = z->ne[2];
int64_t H = z->ne[1];
int64_t W = z->ne[0];
int64_t h = H / p;
int64_t w = W / p;
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p]
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p]
z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p]
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w]
z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w]
}
return z;
}
};
struct VAE : public GGMLRunner {
VAE(ggml_backend_t backend, bool offload_params_to_cpu)
: GGMLRunner(backend, offload_params_to_cpu) {}
virtual bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
virtual ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) = 0;
virtual ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
virtual ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
};
struct FakeVAE : public VAE {
FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
: VAE(backend, offload_params_to_cpu) {}
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx) override {
FakeVAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
: VAE(version, backend, offload_params_to_cpu) {}
int get_encoder_output_channels(int input_channels) {
return input_channels;
}
bool _compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx) override {
if (*output == nullptr && output_ctx != nullptr) {
*output = ggml_dup_tensor(output_ctx, z);
}
@ -642,6 +213,18 @@ struct FakeVAE : public VAE {
return true;
}
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
return vae_output;
}
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
}
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
std::string get_desc() override {
@ -649,126 +232,4 @@ struct FakeVAE : public VAE {
}
};
struct AutoEncoderKL : public VAE {
bool decode_only = true;
AutoencodingEngine ae;
AutoEncoderKL(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
const std::string prefix,
bool decode_only = false,
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), VAE(backend, offload_params_to_cpu) {
bool use_linear_projection = false;
for (const auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
continue;
}
if (ends_with(name, "attn_1.proj_out.weight")) {
if (tensor_storage.n_dims == 2) {
use_linear_projection = true;
}
break;
}
}
ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder);
ae.init(params_ctx, tensor_storage_map, prefix);
}
void set_conv2d_scale(float scale) override {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->set_scale(scale);
}
}
}
std::string get_desc() override {
return "vae";
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
ae.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
auto runner_ctx = get_context();
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
ggml_build_forward_expand(gf, out);
return gf;
}
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) override {
GGML_ASSERT(!decode_only || decode_graph);
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};
// ggml_set_f32(z, 0.5f);
// print_ggml_tensor(z);
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
void test() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
params.mem_buffer = nullptr;
params.no_alloc = false;
struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != nullptr);
{
// CPU, x{1, 3, 64, 64}: Pass
// CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
// CPU, x{2, 3, 64, 64}: Wrong result
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
ggml_set_f32(x, 0.5f);
print_ggml_tensor(x);
struct ggml_tensor* out = nullptr;
int64_t t0 = ggml_time_ms();
compute(8, x, false, &out, work_ctx);
int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("encode test done in %lldms", t1 - t0);
}
if (false) {
// CPU, z{1, 4, 8, 8}: Pass
// CUDA, z{1, 4, 8, 8}: Pass
// CPU, z{3, 4, 8, 8}: Wrong result
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
ggml_set_f32(z, 0.5f);
print_ggml_tensor(z);
struct ggml_tensor* out = nullptr;
int64_t t0 = ggml_time_ms();
compute(8, z, true, &out, work_ctx);
int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);
LOG_DEBUG("decode test done in %lldms", t1 - t0);
}
};
};
#endif
#endif // __VAE_HPP__

View File

@ -1109,7 +1109,8 @@ namespace WAN {
};
struct WanVAERunner : public VAE {
bool decode_only = true;
float scale_factor = 1.0f;
bool decode_only = true;
WanVAE ae;
WanVAERunner(ggml_backend_t backend,
@ -1118,7 +1119,7 @@ namespace WAN {
const std::string prefix = "",
bool decode_only = false,
SDVersion version = VERSION_WAN2)
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(version, backend, offload_params_to_cpu) {
ae.init(params_ctx, tensor_storage_map, prefix);
}
@ -1130,6 +1131,101 @@ namespace WAN {
ae.get_param_tensors(tensors, prefix);
}
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
return vae_output;
}
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
GGML_ASSERT(latents->ne[channel_dim] == 16 || latents->ne[channel_dim] == 48);
if (latents->ne[channel_dim] == 16) { // Wan2.1 VAE
latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
} else if (latents->ne[channel_dim] == 48) { // Wan2.2 VAE
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
latents_std_vec = {
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
}
}
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latents->ne[3]; i++) {
if (channel_dim == 3) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latents->ne[2]; j++) {
if (channel_dim == 2) {
mean = latents_mean_vec[j];
std_ = latents_std_vec[j];
}
for (int k = 0; k < latents->ne[1]; k++) {
for (int l = 0; l < latents->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
value = value * std_ / scale_factor + mean;
ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
}
}
}
}
return vae_latents;
}
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
std::vector<float> latents_mean_vec;
std::vector<float> latents_std_vec;
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
float mean;
float std_;
for (int i = 0; i < latents->ne[3]; i++) {
if (channel_dim == 3) {
mean = latents_mean_vec[i];
std_ = latents_std_vec[i];
}
for (int j = 0; j < latents->ne[2]; j++) {
if (channel_dim == 2) {
mean = latents_mean_vec[j];
std_ = latents_std_vec[j];
}
for (int k = 0; k < latents->ne[1]; k++) {
for (int l = 0; l < latents->ne[0]; l++) {
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
value = (value - mean) * scale_factor / std_;
ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
}
}
}
}
return diffusion_latents;
}
int get_encoder_output_channels(int input_channels) {
return static_cast<int>(ae.z_dim);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = new_graph_custom(10240 * z->ne[2]);
@ -1173,11 +1269,11 @@ namespace WAN {
return gf;
}
bool compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) override {
bool _compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = nullptr) override {
if (true) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
@ -1249,7 +1345,7 @@ namespace WAN {
struct ggml_tensor* out = nullptr;
int64_t t0 = ggml_time_ms();
compute(8, z, true, &out, work_ctx);
_compute(8, z, true, &out, work_ctx);
int64_t t1 = ggml_time_ms();
print_ggml_tensor(out);