mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
Compare commits
7 Commits
d6dd6d7b55
...
61d8331ef3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61d8331ef3 | ||
|
|
acc3bf1fdc | ||
|
|
83eabd7c01 | ||
|
|
630ee03f23 | ||
|
|
f6968bc589 | ||
|
|
adfef62900 | ||
|
|
6fa7ca9317 |
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -162,7 +162,7 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
variant: [musa, sycl, vulkan]
|
||||
variant: [musa, sycl, vulkan, cuda]
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
|
||||
@ -36,7 +36,6 @@ option(SD_VULKAN "sd: vulkan backend" OFF)
|
||||
option(SD_OPENCL "sd: opencl backend" OFF)
|
||||
option(SD_SYCL "sd: sycl backend" OFF)
|
||||
option(SD_MUSA "sd: musa backend" OFF)
|
||||
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
|
||||
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
|
||||
option(SD_BUILD_SHARED_GGML_LIB "sd: build ggml as a separate shared lib" OFF)
|
||||
option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF)
|
||||
@ -70,18 +69,12 @@ if (SD_HIPBLAS)
|
||||
message("-- Use HIPBLAS as backend stable-diffusion")
|
||||
set(GGML_HIP ON)
|
||||
add_definitions(-DSD_USE_CUDA)
|
||||
if(SD_FAST_SOFTMAX)
|
||||
set(GGML_CUDA_FAST_SOFTMAX ON)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if(SD_MUSA)
|
||||
message("-- Use MUSA as backend stable-diffusion")
|
||||
set(GGML_MUSA ON)
|
||||
add_definitions(-DSD_USE_CUDA)
|
||||
if(SD_FAST_SOFTMAX)
|
||||
set(GGML_CUDA_FAST_SOFTMAX ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(SD_LIB stable-diffusion)
|
||||
|
||||
25
Dockerfile.cuda
Normal file
25
Dockerfile.cuda
Normal file
@ -0,0 +1,25 @@
|
||||
ARG CUDA_VERSION=12.6.3
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu${UBUNTU_VERSION} AS build
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git ccache cmake
|
||||
|
||||
WORKDIR /sd.cpp
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG CUDACXX=/usr/local/cuda/bin/nvcc
|
||||
RUN cmake . -B ./build -DSD_CUDA=ON
|
||||
RUN cmake --build ./build --config Release -j$(nproc)
|
||||
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-runtime-ubuntu${UBUNTU_VERSION} AS runtime
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install --yes --no-install-recommends libgomp1 && \
|
||||
apt-get clean
|
||||
|
||||
COPY --from=build /sd.cpp/build/bin/sd-cli /sd-cli
|
||||
COPY --from=build /sd.cpp/build/bin/sd-server /sd-server
|
||||
|
||||
ENTRYPOINT [ "/sd-cli" ]
|
||||
@ -5,6 +5,7 @@
|
||||
- Download Anima
|
||||
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/diffusion_models
|
||||
- gguf: https://huggingface.co/Bedovyy/Anima-GGUF/tree/main
|
||||
- gguf Anima2: https://huggingface.co/JusteLeo/Anima2-GGUF/tree/main
|
||||
- Download vae
|
||||
- safetensors: https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae
|
||||
- Download Qwen3-0.6B-Base
|
||||
|
||||
@ -80,7 +80,7 @@ Uses Taylor series approximation to predict block outputs:
|
||||
Combines DBCache and TaylorSeer:
|
||||
|
||||
```bash
|
||||
--cache-mode cache-dit --cache-preset fast
|
||||
--cache-mode cache-dit
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
@ -92,14 +92,6 @@ Combines DBCache and TaylorSeer:
|
||||
| `threshold` | L1 residual difference threshold | 0.08 |
|
||||
| `warmup` | Steps before caching starts | 8 |
|
||||
|
||||
#### Presets
|
||||
|
||||
Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
|
||||
|
||||
```bash
|
||||
--cache-mode cache-dit --cache-preset fast
|
||||
```
|
||||
|
||||
#### SCM Options
|
||||
|
||||
Steps Computation Mask controls which steps can be cached:
|
||||
|
||||
@ -139,12 +139,11 @@ Generation Options:
|
||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level),
|
||||
'spectrum' (UNET Chebyshev+Taylor forecasting)
|
||||
'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
|
||||
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
|
||||
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
|
||||
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
|
||||
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
|
||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||
```
|
||||
|
||||
@ -1047,7 +1047,6 @@ struct SDGenerationParams {
|
||||
|
||||
std::string cache_mode;
|
||||
std::string cache_option;
|
||||
std::string cache_preset;
|
||||
std::string scm_mask;
|
||||
bool scm_policy_dynamic = true;
|
||||
sd_cache_params_t cache_params{};
|
||||
@ -1461,21 +1460,6 @@ struct SDGenerationParams {
|
||||
return 1;
|
||||
};
|
||||
|
||||
auto on_cache_preset_arg = [&](int argc, const char** argv, int index) {
|
||||
if (++index >= argc) {
|
||||
return -1;
|
||||
}
|
||||
cache_preset = argv_to_utf8(index, argv);
|
||||
if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" &&
|
||||
cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" &&
|
||||
cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" &&
|
||||
cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") {
|
||||
fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str());
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
options.manual_options = {
|
||||
{"-s",
|
||||
"--seed",
|
||||
@ -1513,16 +1497,12 @@ struct SDGenerationParams {
|
||||
on_ref_image_arg},
|
||||
{"",
|
||||
"--cache-mode",
|
||||
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)",
|
||||
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)",
|
||||
on_cache_mode_arg},
|
||||
{"",
|
||||
"--cache-option",
|
||||
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
|
||||
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
|
||||
on_cache_option_arg},
|
||||
{"",
|
||||
"--cache-preset",
|
||||
"cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'",
|
||||
on_cache_preset_arg},
|
||||
{"",
|
||||
"--scm-mask",
|
||||
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
|
||||
@ -1575,7 +1555,6 @@ struct SDGenerationParams {
|
||||
load_if_exists("negative_prompt", negative_prompt);
|
||||
load_if_exists("cache_mode", cache_mode);
|
||||
load_if_exists("cache_option", cache_option);
|
||||
load_if_exists("cache_preset", cache_preset);
|
||||
load_if_exists("scm_mask", scm_mask);
|
||||
|
||||
load_if_exists("clip_skip", clip_skip);
|
||||
@ -1811,47 +1790,16 @@ struct SDGenerationParams {
|
||||
if (!cache_mode.empty()) {
|
||||
if (cache_mode == "easycache") {
|
||||
cache_params.mode = SD_CACHE_EASYCACHE;
|
||||
cache_params.reuse_threshold = 0.2f;
|
||||
cache_params.start_percent = 0.15f;
|
||||
cache_params.end_percent = 0.95f;
|
||||
cache_params.error_decay_rate = 1.0f;
|
||||
cache_params.use_relative_threshold = true;
|
||||
cache_params.reset_error_on_compute = true;
|
||||
} else if (cache_mode == "ucache") {
|
||||
cache_params.mode = SD_CACHE_UCACHE;
|
||||
cache_params.reuse_threshold = 1.0f;
|
||||
cache_params.start_percent = 0.15f;
|
||||
cache_params.end_percent = 0.95f;
|
||||
cache_params.error_decay_rate = 1.0f;
|
||||
cache_params.use_relative_threshold = true;
|
||||
cache_params.reset_error_on_compute = true;
|
||||
} else if (cache_mode == "dbcache") {
|
||||
cache_params.mode = SD_CACHE_DBCACHE;
|
||||
cache_params.Fn_compute_blocks = 8;
|
||||
cache_params.Bn_compute_blocks = 0;
|
||||
cache_params.residual_diff_threshold = 0.08f;
|
||||
cache_params.max_warmup_steps = 8;
|
||||
} else if (cache_mode == "taylorseer") {
|
||||
cache_params.mode = SD_CACHE_TAYLORSEER;
|
||||
cache_params.Fn_compute_blocks = 8;
|
||||
cache_params.Bn_compute_blocks = 0;
|
||||
cache_params.residual_diff_threshold = 0.08f;
|
||||
cache_params.max_warmup_steps = 8;
|
||||
} else if (cache_mode == "cache-dit") {
|
||||
cache_params.mode = SD_CACHE_CACHE_DIT;
|
||||
cache_params.Fn_compute_blocks = 8;
|
||||
cache_params.Bn_compute_blocks = 0;
|
||||
cache_params.residual_diff_threshold = 0.08f;
|
||||
cache_params.max_warmup_steps = 8;
|
||||
} else if (cache_mode == "spectrum") {
|
||||
cache_params.mode = SD_CACHE_SPECTRUM;
|
||||
cache_params.spectrum_w = 0.40f;
|
||||
cache_params.spectrum_m = 3;
|
||||
cache_params.spectrum_lam = 1.0f;
|
||||
cache_params.spectrum_window_size = 2;
|
||||
cache_params.spectrum_flex_window = 0.50f;
|
||||
cache_params.spectrum_warmup_steps = 4;
|
||||
cache_params.spectrum_stop_percent = 0.9f;
|
||||
}
|
||||
|
||||
if (!cache_option.empty()) {
|
||||
|
||||
@ -129,11 +129,10 @@ Default Generation Options:
|
||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level), 'spectrum' (UNET/DiT Chebyshev+Taylor forecasting)
|
||||
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
|
||||
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
|
||||
"threshold=0.25" or "threshold=1.5,reset=0"
|
||||
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
|
||||
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
|
||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||
```
|
||||
|
||||
930
src/auto_encoder_kl.hpp
Normal file
930
src/auto_encoder_kl.hpp
Normal file
@ -0,0 +1,930 @@
|
||||
#ifndef __AUTO_ENCODER_KL_HPP__
|
||||
#define __AUTO_ENCODER_KL_HPP__
|
||||
|
||||
#include "vae.hpp"
|
||||
|
||||
/*================================================== AutoEncoderKL ===================================================*/
|
||||
|
||||
#define VAE_GRAPH_SIZE 20480
|
||||
|
||||
class ResnetBlock : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
int64_t out_channels;
|
||||
|
||||
public:
|
||||
ResnetBlock(int64_t in_channels,
|
||||
int64_t out_channels)
|
||||
: in_channels(in_channels),
|
||||
out_channels(out_channels) {
|
||||
// temb_channels is always 0
|
||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
|
||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
if (out_channels != in_channels) {
|
||||
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w]
|
||||
// t_emb is always None
|
||||
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
|
||||
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
|
||||
|
||||
auto h = x;
|
||||
h = norm1->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||
h = conv1->forward(ctx, h);
|
||||
// return h;
|
||||
|
||||
h = norm2->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||
// dropout, skip for inference
|
||||
h = conv2->forward(ctx, h);
|
||||
|
||||
// skip connection
|
||||
if (out_channels != in_channels) {
|
||||
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
|
||||
|
||||
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
|
||||
}
|
||||
|
||||
h = ggml_add(ctx->ggml_ctx, h, x);
|
||||
return h; // [N, out_channels, h, w]
|
||||
}
|
||||
};
|
||||
|
||||
class AttnBlock : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
bool use_linear;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
|
||||
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
|
||||
if (iter != tensor_storage_map.end()) {
|
||||
if (iter->second.n_dims == 4 && use_linear) {
|
||||
use_linear = false;
|
||||
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
} else if (iter->second.n_dims == 2 && !use_linear) {
|
||||
use_linear = true;
|
||||
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
AttnBlock(int64_t in_channels, bool use_linear)
|
||||
: in_channels(in_channels), use_linear(use_linear) {
|
||||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||
if (use_linear) {
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
} else {
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w]
|
||||
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
||||
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
|
||||
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
|
||||
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
|
||||
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
|
||||
|
||||
auto h_ = norm->forward(ctx, x);
|
||||
|
||||
const int64_t n = h_->ne[3];
|
||||
const int64_t c = h_->ne[2];
|
||||
const int64_t h = h_->ne[1];
|
||||
const int64_t w = h_->ne[0];
|
||||
|
||||
ggml_tensor* q;
|
||||
ggml_tensor* k;
|
||||
ggml_tensor* v;
|
||||
if (use_linear) {
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
} else {
|
||||
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
||||
}
|
||||
|
||||
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
|
||||
|
||||
if (use_linear) {
|
||||
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
} else {
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
|
||||
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
}
|
||||
|
||||
h_ = ggml_add(ctx->ggml_ctx, h_, x);
|
||||
return h_;
|
||||
}
|
||||
};
|
||||
|
||||
class AE3DConv : public Conv2d {
|
||||
public:
|
||||
AE3DConv(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
std::pair<int, int> kernel_size,
|
||||
int video_kernel_size = 3,
|
||||
std::pair<int, int> stride = {1, 1},
|
||||
std::pair<int, int> padding = {0, 0},
|
||||
std::pair<int, int> dilation = {1, 1},
|
||||
bool bias = true)
|
||||
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
|
||||
int kernel_padding = video_kernel_size / 2;
|
||||
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
|
||||
out_channels,
|
||||
{video_kernel_size, 1, 1},
|
||||
{1, 1, 1},
|
||||
{kernel_padding, 0, 0}));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
struct ggml_tensor* x) override {
|
||||
// timesteps always None
|
||||
// skip_video always False
|
||||
// x: [N, IC, IH, IW]
|
||||
// result: [N, OC, OH, OW]
|
||||
auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
|
||||
|
||||
x = Conv2d::forward(ctx, x);
|
||||
// timesteps = x.shape[0]
|
||||
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
// x = conv3d(x)
|
||||
// return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
int64_t T = x->ne[3];
|
||||
int64_t B = x->ne[3] / T;
|
||||
int64_t C = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||
return x; // [B*T, OC, OH, OW]
|
||||
}
|
||||
};
|
||||
|
||||
class VideoResnetBlock : public ResnetBlock {
|
||||
protected:
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
|
||||
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
|
||||
}
|
||||
|
||||
float get_alpha() {
|
||||
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
|
||||
return sigmoid(alpha);
|
||||
}
|
||||
|
||||
public:
|
||||
VideoResnetBlock(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
int video_kernel_size = 3)
|
||||
: ResnetBlock(in_channels, out_channels) {
|
||||
// merge_strategy is always learned
|
||||
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
|
||||
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
|
||||
// t_emb is always None
|
||||
// skip_video is always False
|
||||
// timesteps is always None
|
||||
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
|
||||
|
||||
x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w]
|
||||
// return x;
|
||||
|
||||
int64_t T = x->ne[3];
|
||||
int64_t B = x->ne[3] / T;
|
||||
int64_t C = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||
auto x_mix = x;
|
||||
|
||||
x = time_stack->forward(ctx, x); // b t c (h w)
|
||||
|
||||
float alpha = get_alpha();
|
||||
x = ggml_add(ctx->ggml_ctx,
|
||||
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
|
||||
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
||||
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.modules.diffusionmodules.model.Encoder
|
||||
class Encoder : public GGMLBlock {
|
||||
protected:
|
||||
int ch = 128;
|
||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
int in_channels = 3;
|
||||
int z_channels = 4;
|
||||
bool double_z = true;
|
||||
|
||||
public:
|
||||
Encoder(int ch,
|
||||
std::vector<int> ch_mult,
|
||||
int num_res_blocks,
|
||||
int in_channels,
|
||||
int z_channels,
|
||||
bool double_z = true,
|
||||
bool use_linear_projection = false)
|
||||
: ch(ch),
|
||||
ch_mult(ch_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
in_channels(in_channels),
|
||||
z_channels(z_channels),
|
||||
double_z(double_z) {
|
||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
size_t num_resolutions = ch_mult.size();
|
||||
|
||||
int block_in = 1;
|
||||
for (int i = 0; i < num_resolutions; i++) {
|
||||
if (i == 0) {
|
||||
block_in = ch;
|
||||
} else {
|
||||
block_in = ch * ch_mult[i - 1];
|
||||
}
|
||||
int block_out = ch * ch_mult[i];
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
|
||||
block_in = block_out;
|
||||
}
|
||||
if (i != num_resolutions - 1) {
|
||||
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
|
||||
}
|
||||
}
|
||||
|
||||
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||
|
||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
// x: [N, in_channels, h, w]
|
||||
|
||||
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
||||
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
||||
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
||||
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
||||
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||
|
||||
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
|
||||
|
||||
// downsampling
|
||||
size_t num_resolutions = ch_mult.size();
|
||||
for (int i = 0; i < num_resolutions; i++) {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||
|
||||
h = down_block->forward(ctx, h);
|
||||
}
|
||||
if (i != num_resolutions - 1) {
|
||||
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
|
||||
|
||||
h = down_sample->forward(ctx, h);
|
||||
}
|
||||
}
|
||||
|
||||
// middle
|
||||
h = mid_block_1->forward(ctx, h);
|
||||
h = mid_attn_1->forward(ctx, h);
|
||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||
|
||||
// end
|
||||
h = norm_out->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.modules.diffusionmodules.model.Decoder
|
||||
class Decoder : public GGMLBlock {
|
||||
protected:
|
||||
int ch = 128;
|
||||
int out_ch = 3;
|
||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
int z_channels = 4;
|
||||
bool video_decoder = false;
|
||||
int video_kernel_size = 3;
|
||||
|
||||
virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
std::pair<int, int> kernel_size,
|
||||
std::pair<int, int> stride = {1, 1},
|
||||
std::pair<int, int> padding = {0, 0}) {
|
||||
if (video_decoder) {
|
||||
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
|
||||
} else {
|
||||
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
|
||||
}
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
|
||||
int64_t out_channels) {
|
||||
if (video_decoder) {
|
||||
return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
|
||||
} else {
|
||||
return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Decoder(int ch,
|
||||
int out_ch,
|
||||
std::vector<int> ch_mult,
|
||||
int num_res_blocks,
|
||||
int z_channels,
|
||||
bool use_linear_projection = false,
|
||||
bool video_decoder = false,
|
||||
int video_kernel_size = 3)
|
||||
: ch(ch),
|
||||
out_ch(out_ch),
|
||||
ch_mult(ch_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
z_channels(z_channels),
|
||||
video_decoder(video_decoder),
|
||||
video_kernel_size(video_kernel_size) {
|
||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||
int block_in = ch * ch_mult[num_resolutions - 1];
|
||||
|
||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
|
||||
|
||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||
int mult = ch_mult[i];
|
||||
int block_out = ch * mult;
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
blocks[name] = get_resnet_block(block_in, block_out);
|
||||
|
||||
block_in = block_out;
|
||||
}
|
||||
if (i != 0) {
|
||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
|
||||
}
|
||||
}
|
||||
|
||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
|
||||
}
|
||||
|
||||
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||
// z: [N, z_channels, h, w]
|
||||
// alpha is always 0
|
||||
// merge_strategy is always learned
|
||||
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
|
||||
// AttnVideoBlock will not be used
|
||||
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
||||
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
||||
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
||||
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
||||
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||
|
||||
// conv_in
|
||||
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
|
||||
|
||||
// middle
|
||||
h = mid_block_1->forward(ctx, h);
|
||||
// return h;
|
||||
|
||||
h = mid_attn_1->forward(ctx, h);
|
||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||
|
||||
// upsampling
|
||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||
|
||||
h = up_block->forward(ctx, h);
|
||||
}
|
||||
if (i != 0) {
|
||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
||||
|
||||
h = up_sample->forward(ctx, h);
|
||||
}
|
||||
}
|
||||
|
||||
h = norm_out->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.models.autoencoder.AutoencoderKL
|
||||
class AutoEncoderKLModel : public GGMLBlock {
|
||||
protected:
|
||||
SDVersion version;
|
||||
bool decode_only = true;
|
||||
bool use_video_decoder = false;
|
||||
bool use_quant = true;
|
||||
int embed_dim = 4;
|
||||
struct {
|
||||
int z_channels = 4;
|
||||
int resolution = 256;
|
||||
int in_channels = 3;
|
||||
int out_ch = 3;
|
||||
int ch = 128;
|
||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
bool double_z = true;
|
||||
} dd_config;
|
||||
|
||||
public:
|
||||
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
|
||||
bool decode_only = true,
|
||||
bool use_linear_projection = false,
|
||||
bool use_video_decoder = false)
|
||||
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (sd_version_is_dit(version)) {
|
||||
if (sd_version_is_flux2(version)) {
|
||||
dd_config.z_channels = 32;
|
||||
embed_dim = 32;
|
||||
} else {
|
||||
use_quant = false;
|
||||
dd_config.z_channels = 16;
|
||||
}
|
||||
}
|
||||
if (use_video_decoder) {
|
||||
use_quant = false;
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
|
||||
dd_config.out_ch,
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.z_channels,
|
||||
use_linear_projection,
|
||||
use_video_decoder));
|
||||
if (use_quant) {
|
||||
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
|
||||
embed_dim,
|
||||
{1, 1}));
|
||||
}
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.in_channels,
|
||||
dd_config.z_channels,
|
||||
dd_config.double_z,
|
||||
use_linear_projection));
|
||||
if (use_quant) {
|
||||
int factor = dd_config.double_z ? 2 : 1;
|
||||
|
||||
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
|
||||
dd_config.z_channels * factor,
|
||||
{1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||
// z: [N, z_channels, h, w]
|
||||
if (sd_version_is_flux2(version)) {
|
||||
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
|
||||
int64_t p = 2;
|
||||
|
||||
int64_t N = z->ne[3];
|
||||
int64_t C = z->ne[2] / p / p;
|
||||
int64_t h = z->ne[1];
|
||||
int64_t w = z->ne[0];
|
||||
int64_t H = h * p;
|
||||
int64_t W = w * p;
|
||||
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p]
|
||||
}
|
||||
|
||||
if (use_quant) {
|
||||
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
||||
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
|
||||
}
|
||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||
|
||||
ggml_set_name(z, "bench-start");
|
||||
auto h = decoder->forward(ctx, z);
|
||||
ggml_set_name(h, "bench-end");
|
||||
return h;
|
||||
}
|
||||
|
||||
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
// x: [N, in_channels, h, w]
|
||||
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
|
||||
|
||||
auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
|
||||
if (use_quant) {
|
||||
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
|
||||
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
|
||||
}
|
||||
if (sd_version_is_flux2(version)) {
|
||||
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
|
||||
|
||||
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
|
||||
int64_t p = 2;
|
||||
int64_t N = z->ne[3];
|
||||
int64_t C = z->ne[2];
|
||||
int64_t H = z->ne[1];
|
||||
int64_t W = z->ne[0];
|
||||
int64_t h = H / p;
|
||||
int64_t w = W / p;
|
||||
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w]
|
||||
}
|
||||
return z;
|
||||
}
|
||||
|
||||
int get_encoder_output_channels() {
|
||||
int factor = dd_config.double_z ? 2 : 1;
|
||||
return dd_config.z_channels * factor;
|
||||
}
|
||||
};
|
||||
|
||||
struct AutoEncoderKL : public VAE {
|
||||
float scale_factor = 1.f;
|
||||
float shift_factor = 0.f;
|
||||
bool decode_only = true;
|
||||
AutoEncoderKLModel ae;
|
||||
|
||||
AutoEncoderKL(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decode_only = false,
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), VAE(version, backend, offload_params_to_cpu) {
|
||||
if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
|
||||
scale_factor = 0.18215f;
|
||||
shift_factor = 0.f;
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
scale_factor = 0.13025f;
|
||||
shift_factor = 0.f;
|
||||
} else if (sd_version_is_sd3(version)) {
|
||||
scale_factor = 1.5305f;
|
||||
shift_factor = 0.0609f;
|
||||
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
|
||||
scale_factor = 0.3611f;
|
||||
shift_factor = 0.1159f;
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
scale_factor = 1.0f;
|
||||
shift_factor = 0.f;
|
||||
}
|
||||
bool use_linear_projection = false;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (!starts_with(name, prefix)) {
|
||||
continue;
|
||||
}
|
||||
if (ends_with(name, "attn_1.proj_out.weight")) {
|
||||
if (tensor_storage.n_dims == 2) {
|
||||
use_linear_projection = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
void set_conv2d_scale(float scale) override {
|
||||
std::vector<GGMLBlock*> blocks;
|
||||
ae.get_all_blocks(blocks);
|
||||
for (auto block : blocks) {
|
||||
if (block->get_desc() == "Conv2d") {
|
||||
auto conv_block = (Conv2d*)block;
|
||||
conv_block->set_scale(scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
return "vae";
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
|
||||
ae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
|
||||
z = to_backend(z);
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
|
||||
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
|
||||
|
||||
ggml_build_forward_expand(gf, out);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool _compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
struct ggml_context* output_ctx = nullptr) override {
|
||||
GGML_ASSERT(!decode_only || decode_graph);
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(z, decode_graph);
|
||||
};
|
||||
// ggml_set_f32(z, 0.5f);
|
||||
// print_ggml_tensor(z);
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
}
|
||||
|
||||
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments, std::shared_ptr<RNG> rng) {
|
||||
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
|
||||
ggml_tensor* latents = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latents);
|
||||
ggml_ext_im_set_randn_f32(noise, rng);
|
||||
{
|
||||
float mean = 0;
|
||||
float logvar = 0;
|
||||
float value = 0;
|
||||
float std_ = 0;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
|
||||
logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latents->ne[2], i);
|
||||
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
|
||||
std_ = std::exp(0.5f * logvar);
|
||||
value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
|
||||
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
|
||||
ggml_ext_tensor_set_f32(latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return latents;
|
||||
}
|
||||
|
||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
||||
if (sd_version_is_flux2(version)) {
|
||||
return vae_output;
|
||||
} else if (version == VERSION_SD1_PIX2PIX) {
|
||||
return ggml_view_3d(work_ctx,
|
||||
vae_output,
|
||||
vae_output->ne[0],
|
||||
vae_output->ne[1],
|
||||
vae_output->ne[2] / 2,
|
||||
vae_output->nb[1],
|
||||
vae_output->nb[2],
|
||||
0);
|
||||
} else {
|
||||
return gaussian_latent_sample(work_ctx, vae_output, rng);
|
||||
}
|
||||
}
|
||||
|
||||
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
|
||||
// flux2
|
||||
if (sd_version_is_flux2(version)) {
|
||||
GGML_ASSERT(latents->ne[channel_dim] == 128);
|
||||
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
|
||||
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
|
||||
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
|
||||
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
|
||||
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
|
||||
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
|
||||
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
|
||||
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
|
||||
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
|
||||
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
|
||||
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
|
||||
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
|
||||
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
|
||||
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
|
||||
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
|
||||
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
|
||||
latents_std_vec = {
|
||||
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
|
||||
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
|
||||
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
|
||||
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
|
||||
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
|
||||
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
|
||||
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
|
||||
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
|
||||
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
|
||||
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
|
||||
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
|
||||
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
|
||||
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
|
||||
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
|
||||
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
|
||||
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
|
||||
} else {
|
||||
GGML_ABORT("unknown version %d", version);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
|
||||
if (sd_version_is_flux2(version)) {
|
||||
int channel_dim = 2;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[j];
|
||||
std_ = latents_std_vec[j];
|
||||
}
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
||||
value = value * std_ / scale_factor + mean;
|
||||
ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3);
|
||||
value = (value / scale_factor) + shift_factor;
|
||||
ggml_ext_tensor_set_f32(vae_latents, value, i0, i1, i2, i3);
|
||||
});
|
||||
}
|
||||
return vae_latents;
|
||||
}
|
||||
|
||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
|
||||
if (sd_version_is_flux2(version)) {
|
||||
int channel_dim = 2;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[j];
|
||||
std_ = latents_std_vec[j];
|
||||
}
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
||||
value = (value - mean) * scale_factor / std_;
|
||||
ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3);
|
||||
value = (value - shift_factor) * scale_factor;
|
||||
ggml_ext_tensor_set_f32(diffusion_latents, value, i0, i1, i2, i3);
|
||||
});
|
||||
}
|
||||
return diffusion_latents;
|
||||
}
|
||||
|
||||
int get_encoder_output_channels(int input_channels) {
|
||||
return ae.get_encoder_output_channels();
|
||||
}
|
||||
|
||||
void test() {
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
struct ggml_context* work_ctx = ggml_init(params);
|
||||
GGML_ASSERT(work_ctx != nullptr);
|
||||
|
||||
{
|
||||
// CPU, x{1, 3, 64, 64}: Pass
|
||||
// CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
|
||||
// CPU, x{2, 3, 64, 64}: Wrong result
|
||||
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
|
||||
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
|
||||
ggml_set_f32(x, 0.5f);
|
||||
print_ggml_tensor(x);
|
||||
struct ggml_tensor* out = nullptr;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
_compute(8, x, false, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
LOG_DEBUG("encode test done in %lldms", t1 - t0);
|
||||
}
|
||||
|
||||
if (false) {
|
||||
// CPU, z{1, 4, 8, 8}: Pass
|
||||
// CUDA, z{1, 4, 8, 8}: Pass
|
||||
// CPU, z{3, 4, 8, 8}: Wrong result
|
||||
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
|
||||
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
|
||||
ggml_set_f32(z, 0.5f);
|
||||
print_ggml_tensor(z);
|
||||
struct ggml_tensor* out = nullptr;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
_compute(8, z, true, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
LOG_DEBUG("decode test done in %lldms", t1 - t0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
#endif // __AUTO_ENCODER_KL_HPP__
|
||||
@ -603,87 +603,6 @@ inline std::vector<int> generate_scm_mask(
|
||||
return mask;
|
||||
}
|
||||
|
||||
inline std::vector<int> get_scm_preset(const std::string& preset, int total_steps) {
|
||||
struct Preset {
|
||||
std::vector<int> compute_bins;
|
||||
std::vector<int> cache_bins;
|
||||
};
|
||||
|
||||
Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}};
|
||||
Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}};
|
||||
Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}};
|
||||
Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}};
|
||||
|
||||
Preset* p = nullptr;
|
||||
if (preset == "slow" || preset == "s" || preset == "S")
|
||||
p = &slow;
|
||||
else if (preset == "medium" || preset == "m" || preset == "M")
|
||||
p = &medium;
|
||||
else if (preset == "fast" || preset == "f" || preset == "F")
|
||||
p = &fast;
|
||||
else if (preset == "ultra" || preset == "u" || preset == "U")
|
||||
p = &ultra;
|
||||
else
|
||||
return {};
|
||||
|
||||
if (total_steps != 28 && total_steps > 0) {
|
||||
float scale = static_cast<float>(total_steps) / 28.0f;
|
||||
std::vector<int> scaled_compute, scaled_cache;
|
||||
|
||||
for (int v : p->compute_bins) {
|
||||
scaled_compute.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
|
||||
}
|
||||
for (int v : p->cache_bins) {
|
||||
scaled_cache.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
|
||||
}
|
||||
|
||||
return generate_scm_mask(scaled_compute, scaled_cache, total_steps);
|
||||
}
|
||||
|
||||
return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps);
|
||||
}
|
||||
|
||||
inline float get_preset_threshold(const std::string& preset) {
|
||||
if (preset == "slow" || preset == "s" || preset == "S")
|
||||
return 0.20f;
|
||||
if (preset == "medium" || preset == "m" || preset == "M")
|
||||
return 0.25f;
|
||||
if (preset == "fast" || preset == "f" || preset == "F")
|
||||
return 0.30f;
|
||||
if (preset == "ultra" || preset == "u" || preset == "U")
|
||||
return 0.34f;
|
||||
return 0.08f;
|
||||
}
|
||||
|
||||
inline int get_preset_warmup(const std::string& preset) {
|
||||
if (preset == "slow" || preset == "s" || preset == "S")
|
||||
return 8;
|
||||
if (preset == "medium" || preset == "m" || preset == "M")
|
||||
return 6;
|
||||
if (preset == "fast" || preset == "f" || preset == "F")
|
||||
return 6;
|
||||
if (preset == "ultra" || preset == "u" || preset == "U")
|
||||
return 4;
|
||||
return 8;
|
||||
}
|
||||
|
||||
inline int get_preset_Fn(const std::string& preset) {
|
||||
if (preset == "slow" || preset == "s" || preset == "S")
|
||||
return 8;
|
||||
if (preset == "medium" || preset == "m" || preset == "M")
|
||||
return 8;
|
||||
if (preset == "fast" || preset == "f" || preset == "F")
|
||||
return 6;
|
||||
if (preset == "ultra" || preset == "u" || preset == "U")
|
||||
return 4;
|
||||
return 8;
|
||||
}
|
||||
|
||||
inline int get_preset_Bn(const std::string& preset) {
|
||||
(void)preset;
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
|
||||
if (opts.empty())
|
||||
return;
|
||||
|
||||
@ -377,6 +377,12 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* ggml_ext_dup_and_cpy_tensor(ggml_context* ctx, ggml_tensor* src) {
|
||||
ggml_tensor* dup = ggml_dup_tensor(ctx, src);
|
||||
copy_ggml_tensor(dup, src);
|
||||
return dup;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ float sigmoid(float x) {
|
||||
return 1 / (1.0f + expf(-x));
|
||||
}
|
||||
@ -637,7 +643,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_tensor_concat(struct ggml_context
|
||||
}
|
||||
|
||||
// convert values from [0, 1] to [-1, 1]
|
||||
__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
|
||||
__STATIC_INLINE__ void scale_to_minus1_1(struct ggml_tensor* src) {
|
||||
int64_t nelements = ggml_nelements(src);
|
||||
float* data = (float*)src->data;
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
@ -647,7 +653,7 @@ __STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) {
|
||||
}
|
||||
|
||||
// convert values from [-1, 1] to [0, 1]
|
||||
__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) {
|
||||
__STATIC_INLINE__ void scale_to_0_1(struct ggml_tensor* src) {
|
||||
int64_t nelements = ggml_nelements(src);
|
||||
float* data = (float*)src->data;
|
||||
for (int i = 0; i < nelements; i++) {
|
||||
@ -834,7 +840,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
||||
const float tile_overlap_factor,
|
||||
const bool circular_x,
|
||||
const bool circular_y,
|
||||
on_tile_process on_processing) {
|
||||
on_tile_process on_processing,
|
||||
bool slient = false) {
|
||||
output = ggml_set_f32(output, 0);
|
||||
|
||||
int input_width = (int)input->ne[0];
|
||||
@ -864,8 +871,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
||||
float tile_overlap_factor_y;
|
||||
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y);
|
||||
|
||||
if (!slient) {
|
||||
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
|
||||
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
|
||||
}
|
||||
|
||||
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
|
||||
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
|
||||
@ -896,7 +905,9 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
if (!slient) {
|
||||
LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f);
|
||||
}
|
||||
|
||||
// draft context
|
||||
struct ggml_context* tiles_ctx = ggml_init(params);
|
||||
@ -909,8 +920,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
||||
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
|
||||
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
|
||||
int num_tiles = num_tiles_x * num_tiles_y;
|
||||
if (!slient) {
|
||||
LOG_DEBUG("processing %i tiles", num_tiles);
|
||||
pretty_progress(0, num_tiles, 0.0f);
|
||||
}
|
||||
int tile_count = 1;
|
||||
bool last_y = false, last_x = false;
|
||||
float last_time = 0.0f;
|
||||
@ -960,9 +973,11 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
|
||||
}
|
||||
last_x = false;
|
||||
}
|
||||
if (!slient) {
|
||||
if (tile_count < num_tiles) {
|
||||
pretty_progress(num_tiles, num_tiles, last_time);
|
||||
}
|
||||
}
|
||||
ggml_free(tiles_ctx);
|
||||
}
|
||||
|
||||
|
||||
@ -1104,10 +1104,12 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
|
||||
has_middle_block_1 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos ||
|
||||
tensor_storage.name.find("unet.up_blocks.1.attentions.0.transformer_blocks.1") != std::string::npos) {
|
||||
has_output_block_311 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos ||
|
||||
tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) {
|
||||
has_output_block_71 = true;
|
||||
}
|
||||
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
|
||||
|
||||
@ -1120,7 +1120,11 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
||||
for (const auto& prefix : first_stage_model_prefix_vec) {
|
||||
if (starts_with(name, prefix)) {
|
||||
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
|
||||
if (version == VERSION_SDXS) {
|
||||
name = "tae." + name;
|
||||
} else {
|
||||
name = prefix + name;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "stable-diffusion.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "auto_encoder_kl.hpp"
|
||||
#include "cache_dit.hpp"
|
||||
#include "conditioner.hpp"
|
||||
#include "control.hpp"
|
||||
@ -90,12 +91,17 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
|
||||
}
|
||||
}
|
||||
|
||||
void suppress_pp(int step, int steps, float time, void* data) {
|
||||
(void)step;
|
||||
(void)steps;
|
||||
(void)time;
|
||||
(void)data;
|
||||
return;
|
||||
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
|
||||
float reuse_threshold = params.reuse_threshold;
|
||||
if (reuse_threshold == INFINITY) {
|
||||
if (params.mode == SD_CACHE_EASYCACHE) {
|
||||
reuse_threshold = 0.2;
|
||||
}
|
||||
else if (params.mode == SD_CACHE_UCACHE) {
|
||||
reuse_threshold = 1.0;
|
||||
}
|
||||
}
|
||||
return std::max(0.0f, reuse_threshold);
|
||||
}
|
||||
|
||||
/*=============================================== StableDiffusionGGML ================================================*/
|
||||
@ -118,8 +124,6 @@ public:
|
||||
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
|
||||
std::shared_ptr<RNG> sampler_rng = nullptr;
|
||||
int n_threads = -1;
|
||||
float scale_factor = 0.18215f;
|
||||
float shift_factor = 0.f;
|
||||
float default_flow_shift = INFINITY;
|
||||
|
||||
std::shared_ptr<Conditioner> cond_stage_model;
|
||||
@ -127,7 +131,7 @@ public:
|
||||
std::shared_ptr<DiffusionModel> diffusion_model;
|
||||
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
|
||||
std::shared_ptr<VAE> first_stage_model;
|
||||
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
|
||||
std::shared_ptr<VAE> preview_vae;
|
||||
std::shared_ptr<ControlNet> control_net;
|
||||
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
|
||||
std::shared_ptr<LoraModel> pmid_lora;
|
||||
@ -138,7 +142,6 @@ public:
|
||||
bool apply_lora_immediately = false;
|
||||
|
||||
std::string taesd_path;
|
||||
bool use_tiny_autoencoder = false;
|
||||
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
|
||||
bool offload_params_to_cpu = false;
|
||||
bool use_pmid = false;
|
||||
@ -239,10 +242,10 @@ public:
|
||||
n_threads = sd_ctx_params->n_threads;
|
||||
vae_decode_only = sd_ctx_params->vae_decode_only;
|
||||
free_params_immediately = sd_ctx_params->free_params_immediately;
|
||||
taesd_path = SAFE_STR(sd_ctx_params->taesd_path);
|
||||
use_tiny_autoencoder = taesd_path.size() > 0;
|
||||
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
|
||||
|
||||
bool use_tae = false;
|
||||
|
||||
rng = get_rng(sd_ctx_params->rng_type);
|
||||
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) {
|
||||
sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
|
||||
@ -332,6 +335,14 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->taesd_path)) > 0) {
|
||||
LOG_INFO("loading tae from '%s'", sd_ctx_params->taesd_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->taesd_path, "tae.")) {
|
||||
LOG_WARN("loading tae from '%s' failed", sd_ctx_params->taesd_path);
|
||||
}
|
||||
use_tae = true;
|
||||
}
|
||||
|
||||
model_loader.convert_tensors_name();
|
||||
|
||||
version = model_loader.get_sd_version();
|
||||
@ -400,22 +411,6 @@ public:
|
||||
apply_lora_immediately = false;
|
||||
}
|
||||
|
||||
if (sd_version_is_sdxl(version)) {
|
||||
scale_factor = 0.13025f;
|
||||
} else if (sd_version_is_sd3(version)) {
|
||||
scale_factor = 1.5305f;
|
||||
shift_factor = 0.0609f;
|
||||
} else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) {
|
||||
scale_factor = 0.3611f;
|
||||
shift_factor = 0.1159f;
|
||||
} else if (sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_anima(version) ||
|
||||
sd_version_is_flux2(version)) {
|
||||
scale_factor = 1.0f;
|
||||
shift_factor = 0.f;
|
||||
}
|
||||
|
||||
if (sd_version_is_control(version)) {
|
||||
// Might need vae encode for control cond
|
||||
vae_decode_only = false;
|
||||
@ -424,6 +419,7 @@ public:
|
||||
bool tae_preview_only = sd_ctx_params->tae_preview_only;
|
||||
if (version == VERSION_SDXS) {
|
||||
tae_preview_only = false;
|
||||
use_tae = true;
|
||||
}
|
||||
|
||||
if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) {
|
||||
@ -610,31 +606,46 @@ public:
|
||||
vae_backend = backend;
|
||||
}
|
||||
|
||||
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
auto create_tae = [&]() -> std::shared_ptr<VAE> {
|
||||
if (sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_anima(version)) {
|
||||
return std::make_shared<TinyVideoAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder",
|
||||
vae_decode_only,
|
||||
version);
|
||||
|
||||
} else {
|
||||
auto model = std::make_shared<TinyImageAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder.layers",
|
||||
vae_decode_only,
|
||||
version);
|
||||
return model;
|
||||
}
|
||||
};
|
||||
|
||||
auto create_vae = [&]() -> std::shared_ptr<VAE> {
|
||||
if (sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_anima(version)) {
|
||||
return std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
version);
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
|
||||
offload_params_to_cpu);
|
||||
} else {
|
||||
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||
auto model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
false,
|
||||
version);
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the vae model");
|
||||
first_stage_model->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
if (sd_version_is_sdxl(version) &&
|
||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) {
|
||||
float vae_conv_2d_scale = 1.f / 32.f;
|
||||
@ -642,35 +653,40 @@ public:
|
||||
"No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
|
||||
"using Conv2D scale %.3f",
|
||||
vae_conv_2d_scale);
|
||||
first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
|
||||
model->set_conv2d_scale(vae_conv_2d_scale);
|
||||
}
|
||||
return model;
|
||||
}
|
||||
};
|
||||
|
||||
if (version == VERSION_CHROMA_RADIANCE) {
|
||||
LOG_INFO("using FakeVAE");
|
||||
first_stage_model = std::make_shared<FakeVAE>(version,
|
||||
vae_backend,
|
||||
offload_params_to_cpu);
|
||||
} else if (use_tae && !tae_preview_only) {
|
||||
LOG_INFO("using TAE for encoding / decoding");
|
||||
first_stage_model = create_tae();
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "tae");
|
||||
} else {
|
||||
LOG_INFO("using VAE for encoding / decoding");
|
||||
first_stage_model = create_vae();
|
||||
first_stage_model->alloc_params_buffer();
|
||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||
if (use_tae && tae_preview_only) {
|
||||
LOG_INFO("using TAE for preview");
|
||||
preview_vae = create_tae();
|
||||
preview_vae->alloc_params_buffer();
|
||||
preview_vae->get_param_tensors(tensors, "tae");
|
||||
}
|
||||
}
|
||||
if (use_tiny_autoencoder || version == VERSION_SDXS) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder",
|
||||
vae_decode_only,
|
||||
version);
|
||||
} else {
|
||||
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
tensor_storage_map,
|
||||
"decoder.layers",
|
||||
vae_decode_only,
|
||||
version);
|
||||
if (version == VERSION_SDXS) {
|
||||
tae_first_stage->alloc_params_buffer();
|
||||
tae_first_stage->get_param_tensors(tensors, "first_stage_model");
|
||||
}
|
||||
}
|
||||
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the tae model");
|
||||
tae_first_stage->set_conv2d_direct_enabled(true);
|
||||
LOG_INFO("Using Conv2d direct in the vae model");
|
||||
first_stage_model->set_conv2d_direct_enabled(true);
|
||||
if (preview_vae) {
|
||||
preview_vae->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -743,8 +759,8 @@ public:
|
||||
if (first_stage_model) {
|
||||
first_stage_model->set_flash_attention_enabled(true);
|
||||
}
|
||||
if (tae_first_stage) {
|
||||
tae_first_stage->set_flash_attention_enabled(true);
|
||||
if (preview_vae) {
|
||||
preview_vae->set_flash_attention_enabled(true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -782,7 +798,7 @@ public:
|
||||
|
||||
std::set<std::string> ignore_tensors;
|
||||
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
|
||||
if (use_tiny_autoencoder) {
|
||||
if (use_tae && !tae_preview_only) {
|
||||
ignore_tensors.insert("first_stage_model.");
|
||||
}
|
||||
if (use_pmid) {
|
||||
@ -796,6 +812,7 @@ public:
|
||||
ignore_tensors.insert("first_stage_model.encoder");
|
||||
ignore_tensors.insert("first_stage_model.conv1");
|
||||
ignore_tensors.insert("first_stage_model.quant");
|
||||
ignore_tensors.insert("tae.encoder");
|
||||
ignore_tensors.insert("text_encoders.llm.visual.");
|
||||
}
|
||||
if (version == VERSION_OVIS_IMAGE) {
|
||||
@ -822,15 +839,9 @@ public:
|
||||
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
|
||||
}
|
||||
size_t vae_params_mem_size = 0;
|
||||
if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) {
|
||||
vae_params_mem_size = first_stage_model->get_params_buffer_size();
|
||||
}
|
||||
if (use_tiny_autoencoder || version == VERSION_SDXS) {
|
||||
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
|
||||
return false;
|
||||
}
|
||||
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
|
||||
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
|
||||
if (preview_vae) {
|
||||
vae_params_mem_size += preview_vae->get_params_buffer_size();
|
||||
}
|
||||
size_t control_net_params_mem_size = 0;
|
||||
if (control_net) {
|
||||
@ -983,7 +994,6 @@ public:
|
||||
}
|
||||
|
||||
ggml_free(ctx);
|
||||
use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1422,8 +1432,7 @@ public:
|
||||
ggml_ext_tensor_scale_inplace(noise, augmentation_level);
|
||||
ggml_ext_tensor_add_inplace(init_img, noise);
|
||||
}
|
||||
ggml_tensor* moments = vae_encode(work_ctx, init_img);
|
||||
c_concat = get_first_stage_encoding(work_ctx, moments);
|
||||
c_concat = encode_first_stage(work_ctx, init_img);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1475,14 +1484,6 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
|
||||
sd_progress_cb_t cb = sd_get_progress_callback();
|
||||
void* cbd = sd_get_progress_callback_data();
|
||||
sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr);
|
||||
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing);
|
||||
sd_set_progress_callback(cb, cbd);
|
||||
}
|
||||
|
||||
void preview_image(ggml_context* work_ctx,
|
||||
int step,
|
||||
struct ggml_tensor* latents,
|
||||
@ -1575,37 +1576,14 @@ public:
|
||||
free(data);
|
||||
free(images);
|
||||
} else {
|
||||
if (preview_mode == PREVIEW_VAE) {
|
||||
process_latent_out(latents);
|
||||
if (vae_tiling_params.enabled) {
|
||||
// split latent in 32x32 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
|
||||
};
|
||||
silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling);
|
||||
|
||||
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
|
||||
if (preview_vae) {
|
||||
latents = preview_vae->diffusion_to_vae_latents(work_ctx, latents);
|
||||
result = preview_vae->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true);
|
||||
} else {
|
||||
first_stage_model->compute(n_threads, latents, true, &result, work_ctx);
|
||||
latents = first_stage_model->diffusion_to_vae_latents(work_ctx, latents);
|
||||
result = first_stage_model->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true);
|
||||
}
|
||||
|
||||
first_stage_model->free_compute_buffer();
|
||||
process_vae_output_tensor(result);
|
||||
process_latent_in(latents);
|
||||
} else if (preview_mode == PREVIEW_TAE) {
|
||||
if (tae_first_stage == nullptr) {
|
||||
LOG_WARN("TAE not found for preview");
|
||||
return;
|
||||
}
|
||||
if (vae_tiling_params.enabled) {
|
||||
// split latent in 64x64 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return tae_first_stage->compute(n_threads, in, true, &out, nullptr);
|
||||
};
|
||||
silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling);
|
||||
} else {
|
||||
tae_first_stage->compute(n_threads, latents, true, &result, work_ctx);
|
||||
}
|
||||
tae_first_stage->free_compute_buffer();
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
@ -1715,7 +1693,7 @@ public:
|
||||
} else {
|
||||
EasyCacheConfig easycache_config;
|
||||
easycache_config.enabled = true;
|
||||
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
|
||||
easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
|
||||
easycache_config.start_percent = cache_params->start_percent;
|
||||
easycache_config.end_percent = cache_params->end_percent;
|
||||
easycache_state.init(easycache_config, denoiser.get());
|
||||
@ -1736,7 +1714,7 @@ public:
|
||||
} else {
|
||||
UCacheConfig ucache_config;
|
||||
ucache_config.enabled = true;
|
||||
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
|
||||
ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
|
||||
ucache_config.start_percent = cache_params->start_percent;
|
||||
ucache_config.end_percent = cache_params->end_percent;
|
||||
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
|
||||
@ -1797,9 +1775,9 @@ public:
|
||||
}
|
||||
}
|
||||
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
|
||||
bool spectrum_supported = sd_version_is_unet(version);
|
||||
bool spectrum_supported = sd_version_is_unet(version) || sd_version_is_dit(version);
|
||||
if (!spectrum_supported) {
|
||||
LOG_WARN("Spectrum requested but not supported for this model type (only UNET models)");
|
||||
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
|
||||
} else {
|
||||
SpectrumConfig spectrum_config;
|
||||
spectrum_config.w = cache_params->spectrum_w;
|
||||
@ -1829,8 +1807,7 @@ public:
|
||||
}
|
||||
|
||||
size_t steps = sigmas.size() - 1;
|
||||
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
|
||||
copy_ggml_tensor(x, init_latent);
|
||||
struct ggml_tensor* x = ggml_ext_dup_and_cpy_tensor(work_ctx, init_latent);
|
||||
|
||||
if (noise) {
|
||||
x = denoiser->noise_scaling(sigmas[0], noise, x);
|
||||
@ -2351,15 +2328,7 @@ public:
|
||||
}
|
||||
|
||||
int get_vae_scale_factor() {
|
||||
int vae_scale_factor = 8;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
vae_scale_factor = 16;
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
vae_scale_factor = 16;
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
vae_scale_factor = 1;
|
||||
}
|
||||
return vae_scale_factor;
|
||||
return first_stage_model->get_scale_factor();
|
||||
}
|
||||
|
||||
int get_diffusion_model_down_factor() {
|
||||
@ -2414,383 +2383,28 @@ public:
|
||||
} else {
|
||||
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||
}
|
||||
ggml_set_f32(init_latent, shift_factor);
|
||||
ggml_set_f32(init_latent, 0.f);
|
||||
return init_latent;
|
||||
}
|
||||
|
||||
void get_latents_mean_std_vec(ggml_tensor* latent, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
|
||||
GGML_ASSERT(latent->ne[channel_dim] == 16 || latent->ne[channel_dim] == 48 || latent->ne[channel_dim] == 128);
|
||||
if (latent->ne[channel_dim] == 16) {
|
||||
latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||
} else if (latent->ne[channel_dim] == 48) {
|
||||
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
|
||||
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
|
||||
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
|
||||
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
|
||||
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
|
||||
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
|
||||
latents_std_vec = {
|
||||
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
|
||||
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
|
||||
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
|
||||
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
|
||||
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
|
||||
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
|
||||
} else if (latent->ne[channel_dim] == 128) {
|
||||
// flux2
|
||||
latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f,
|
||||
-0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f,
|
||||
-0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f,
|
||||
0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f,
|
||||
-0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f,
|
||||
0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f,
|
||||
0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f,
|
||||
0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f,
|
||||
-0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f,
|
||||
0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f,
|
||||
-0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f,
|
||||
0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f,
|
||||
-0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f,
|
||||
0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f,
|
||||
-0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f,
|
||||
-0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f};
|
||||
latents_std_vec = {
|
||||
1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f,
|
||||
1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f,
|
||||
1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f,
|
||||
1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f,
|
||||
1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f,
|
||||
1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f,
|
||||
1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f,
|
||||
1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f,
|
||||
1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f,
|
||||
1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f,
|
||||
1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f,
|
||||
1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f,
|
||||
1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f,
|
||||
1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f,
|
||||
1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f,
|
||||
1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f};
|
||||
}
|
||||
}
|
||||
|
||||
void process_latent_in(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
|
||||
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latent->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latent->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int k = 0; k < latent->ne[1]; k++) {
|
||||
for (int l = 0; l < latent->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
|
||||
value = (value - mean) * scale_factor / std_;
|
||||
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
// pass
|
||||
} else {
|
||||
ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||
value = (value - shift_factor) * scale_factor;
|
||||
ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void process_latent_out(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) {
|
||||
int channel_dim = sd_version_is_flux2(version) ? 2 : 3;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latent->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latent->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int k = 0; k < latent->ne[1]; k++) {
|
||||
for (int l = 0; l < latent->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latent, l, k, j, i);
|
||||
value = value * std_ / scale_factor + mean;
|
||||
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
// pass
|
||||
} else {
|
||||
ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||
value = (value / scale_factor) + shift_factor;
|
||||
ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void get_tile_sizes(int& tile_size_x,
|
||||
int& tile_size_y,
|
||||
float& tile_overlap,
|
||||
const sd_tiling_params_t& params,
|
||||
int64_t latent_x,
|
||||
int64_t latent_y,
|
||||
float encoding_factor = 1.0f) {
|
||||
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
||||
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
|
||||
const int default_tile_size = 32;
|
||||
const int min_tile_dimension = 4;
|
||||
int tile_size = default_tile_size;
|
||||
// factor <= 1 means simple fraction of the latent dimension
|
||||
// factor > 1 means number of tiles across that dimension
|
||||
if (factor > 0.f) {
|
||||
if (factor > 1.0)
|
||||
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
||||
tile_size = static_cast<int>(std::round(latent_size * factor));
|
||||
} else if (requested_size >= min_tile_dimension) {
|
||||
tile_size = requested_size;
|
||||
}
|
||||
tile_size = static_cast<int>(tile_size * encoding_factor);
|
||||
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
|
||||
};
|
||||
|
||||
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
||||
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x) {
|
||||
int64_t t0 = ggml_time_ms();
|
||||
ggml_tensor* result = nullptr;
|
||||
const int vae_scale_factor = get_vae_scale_factor();
|
||||
int64_t W = x->ne[0] / vae_scale_factor;
|
||||
int64_t H = x->ne[1] / vae_scale_factor;
|
||||
int64_t C = get_latent_channel();
|
||||
if (vae_tiling_params.enabled) {
|
||||
// TODO wan2.2 vae support?
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
ne2 = 1;
|
||||
ne3 = C * x->ne[3];
|
||||
} else {
|
||||
int64_t out_channels = C;
|
||||
bool encode_outputs_mu = use_tiny_autoencoder ||
|
||||
sd_version_is_wan(version) ||
|
||||
sd_version_is_flux2(version) ||
|
||||
version == VERSION_CHROMA_RADIANCE;
|
||||
if (!encode_outputs_mu) {
|
||||
out_channels *= 2;
|
||||
}
|
||||
ne2 = out_channels;
|
||||
ne3 = x->ne[3];
|
||||
}
|
||||
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
|
||||
}
|
||||
|
||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||
}
|
||||
|
||||
if (!use_tiny_autoencoder) {
|
||||
process_vae_input_tensor(x);
|
||||
if (vae_tiling_params.enabled) {
|
||||
float tile_overlap;
|
||||
int tile_size_x, tile_size_y;
|
||||
// multiply tile size for encode to keep the compute buffer size consistent
|
||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f);
|
||||
|
||||
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
||||
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return first_stage_model->compute(n_threads, in, false, &out, work_ctx);
|
||||
};
|
||||
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
|
||||
} else {
|
||||
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
|
||||
}
|
||||
first_stage_model->free_compute_buffer();
|
||||
} else {
|
||||
if (vae_tiling_params.enabled) {
|
||||
// split latent in 32x32 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return tae_first_stage->compute(n_threads, in, false, &out, nullptr);
|
||||
};
|
||||
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
|
||||
} else {
|
||||
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
|
||||
}
|
||||
tae_first_stage->free_compute_buffer();
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments) {
|
||||
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
|
||||
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
|
||||
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
|
||||
ggml_ext_im_set_randn_f32(noise, rng);
|
||||
{
|
||||
float mean = 0;
|
||||
float logvar = 0;
|
||||
float value = 0;
|
||||
float std_ = 0;
|
||||
for (int i = 0; i < latent->ne[3]; i++) {
|
||||
for (int j = 0; j < latent->ne[2]; j++) {
|
||||
for (int k = 0; k < latent->ne[1]; k++) {
|
||||
for (int l = 0; l < latent->ne[0]; l++) {
|
||||
mean = ggml_ext_tensor_get_f32(moments, l, k, j, i);
|
||||
logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i);
|
||||
logvar = std::max(-30.0f, std::min(logvar, 20.0f));
|
||||
std_ = std::exp(0.5f * logvar);
|
||||
value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i);
|
||||
// printf("%d %d %d %d -> %f\n", i, j, k, l, value);
|
||||
ggml_ext_tensor_set_f32(latent, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return latent;
|
||||
}
|
||||
|
||||
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
|
||||
ggml_tensor* latent;
|
||||
if (use_tiny_autoencoder ||
|
||||
sd_version_is_qwen_image(version) ||
|
||||
sd_version_is_anima(version) ||
|
||||
sd_version_is_wan(version) ||
|
||||
sd_version_is_flux2(version) ||
|
||||
version == VERSION_CHROMA_RADIANCE) {
|
||||
latent = vae_output;
|
||||
} else if (version == VERSION_SD1_PIX2PIX) {
|
||||
latent = ggml_view_3d(work_ctx,
|
||||
vae_output,
|
||||
vae_output->ne[0],
|
||||
vae_output->ne[1],
|
||||
vae_output->ne[2] / 2,
|
||||
vae_output->nb[1],
|
||||
vae_output->nb[2],
|
||||
0);
|
||||
} else {
|
||||
latent = gaussian_latent_sample(work_ctx, vae_output);
|
||||
}
|
||||
if (!use_tiny_autoencoder && version != VERSION_SD1_PIX2PIX) {
|
||||
process_latent_in(latent);
|
||||
}
|
||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1);
|
||||
}
|
||||
return latent;
|
||||
ggml_tensor* encode_to_vae_latents(ggml_context* work_ctx, ggml_tensor* x) {
|
||||
ggml_tensor* vae_output = first_stage_model->encode(n_threads, work_ctx, x, vae_tiling_params, circular_x, circular_y);
|
||||
ggml_tensor* latents = first_stage_model->vae_output_to_latents(work_ctx, vae_output, rng);
|
||||
return latents;
|
||||
}
|
||||
|
||||
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) {
|
||||
ggml_tensor* vae_output = vae_encode(work_ctx, x);
|
||||
return get_first_stage_encoding(work_ctx, vae_output);
|
||||
ggml_tensor* latents = encode_to_vae_latents(work_ctx, x);
|
||||
if (version != VERSION_SD1_PIX2PIX) {
|
||||
latents = first_stage_model->vae_to_diffuison_latents(work_ctx, latents);
|
||||
}
|
||||
return latents;
|
||||
}
|
||||
|
||||
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||
const int vae_scale_factor = get_vae_scale_factor();
|
||||
int64_t W = x->ne[0] * vae_scale_factor;
|
||||
int64_t H = x->ne[1] * vae_scale_factor;
|
||||
int64_t C = 3;
|
||||
ggml_tensor* result = nullptr;
|
||||
if (decode_video) {
|
||||
int64_t T = x->ne[2];
|
||||
if (sd_version_is_wan(version)) {
|
||||
T = ((T - 1) * 4) + 1;
|
||||
}
|
||||
result = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
W,
|
||||
H,
|
||||
T,
|
||||
3);
|
||||
} else {
|
||||
result = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
W,
|
||||
H,
|
||||
C,
|
||||
x->ne[3]);
|
||||
}
|
||||
int64_t t0 = ggml_time_ms();
|
||||
if (!use_tiny_autoencoder) {
|
||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||
}
|
||||
process_latent_out(x);
|
||||
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||
if (vae_tiling_params.enabled) {
|
||||
float tile_overlap;
|
||||
int tile_size_x, tile_size_y;
|
||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]);
|
||||
|
||||
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
||||
|
||||
// split latent in 32x32 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
|
||||
};
|
||||
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
|
||||
} else {
|
||||
if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) {
|
||||
LOG_ERROR("Failed to decode latetnts");
|
||||
first_stage_model->free_compute_buffer();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
first_stage_model->free_compute_buffer();
|
||||
process_vae_output_tensor(result);
|
||||
} else {
|
||||
if (vae_tiling_params.enabled) {
|
||||
// split latent in 64x64 tiles and compute in several steps
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return tae_first_stage->compute(n_threads, in, true, &out);
|
||||
};
|
||||
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling);
|
||||
} else {
|
||||
if (!tae_first_stage->compute(n_threads, x, true, &result)) {
|
||||
LOG_ERROR("Failed to decode latetnts");
|
||||
tae_first_stage->free_compute_buffer();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
tae_first_stage->free_compute_buffer();
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
|
||||
return result;
|
||||
x = first_stage_model->diffusion_to_vae_latents(work_ctx, x);
|
||||
x = first_stage_model->decode(n_threads, work_ctx, x, vae_tiling_params, decode_video, circular_x, circular_y);
|
||||
return x;
|
||||
}
|
||||
|
||||
void set_flow_shift(float flow_shift = INFINITY) {
|
||||
@ -2983,7 +2597,7 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
|
||||
void sd_cache_params_init(sd_cache_params_t* cache_params) {
|
||||
*cache_params = {};
|
||||
cache_params->mode = SD_CACHE_DISABLED;
|
||||
cache_params->reuse_threshold = 1.0f;
|
||||
cache_params->reuse_threshold = INFINITY;
|
||||
cache_params->start_percent = 0.15f;
|
||||
cache_params->end_percent = 0.95f;
|
||||
cache_params->error_decay_rate = 1.0f;
|
||||
@ -3229,7 +2843,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
snprintf(buf + strlen(buf), 4096 - strlen(buf),
|
||||
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
|
||||
cache_mode_str,
|
||||
sd_img_gen_params->cache.reuse_threshold,
|
||||
get_cache_reuse_threshold(sd_img_gen_params->cache),
|
||||
sd_img_gen_params->cache.start_percent,
|
||||
sd_img_gen_params->cache.end_percent);
|
||||
free(sample_params_str);
|
||||
@ -3560,7 +3174,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
|
||||
int64_t t4 = ggml_time_ms();
|
||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
|
||||
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
|
||||
if (sd_ctx->sd->free_params_immediately) {
|
||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||
}
|
||||
|
||||
@ -3609,15 +3223,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
if (sd_ctx->sd->first_stage_model) {
|
||||
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||
}
|
||||
if (sd_ctx->sd->tae_first_stage) {
|
||||
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||
if (sd_ctx->sd->preview_vae) {
|
||||
sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||
}
|
||||
} else {
|
||||
int tile_size_x, tile_size_y;
|
||||
float _overlap;
|
||||
int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor();
|
||||
int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor();
|
||||
sd_ctx->sd->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
|
||||
sd_ctx->sd->first_stage_model->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y);
|
||||
|
||||
// force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size
|
||||
// otherwise it will cause artifacts at the edges of the tiles
|
||||
@ -3627,8 +3241,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
if (sd_ctx->sd->first_stage_model) {
|
||||
sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||
}
|
||||
if (sd_ctx->sd->tae_first_stage) {
|
||||
sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||
if (sd_ctx->sd->preview_vae) {
|
||||
sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y);
|
||||
}
|
||||
|
||||
// disable circular tiling if it's enabled for the VAE
|
||||
@ -4105,14 +3719,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, init_img);
|
||||
init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3);
|
||||
|
||||
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
||||
auto init_image_latent = sd_ctx->sd->encode_to_vae_latents(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
||||
|
||||
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
|
||||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||
ggml_set_f32(denoise_mask, 1.f);
|
||||
|
||||
if (!sd_ctx->sd->use_tiny_autoencoder)
|
||||
sd_ctx->sd->process_latent_out(init_latent);
|
||||
init_latent = sd_ctx->sd->first_stage_model->diffusion_to_vae_latents(work_ctx, init_latent);
|
||||
|
||||
ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3);
|
||||
@ -4122,8 +3735,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
}
|
||||
});
|
||||
|
||||
if (!sd_ctx->sd->use_tiny_autoencoder)
|
||||
sd_ctx->sd->process_latent_in(init_latent);
|
||||
init_latent = sd_ctx->sd->first_stage_model->vae_to_diffuison_latents(work_ctx, init_latent);
|
||||
|
||||
int64_t t2 = ggml_time_ms();
|
||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||
@ -4346,7 +3958,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true);
|
||||
int64_t t5 = ggml_time_ms();
|
||||
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
|
||||
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
|
||||
if (sd_ctx->sd->free_params_immediately) {
|
||||
sd_ctx->sd->first_stage_model->free_params_buffer();
|
||||
}
|
||||
|
||||
|
||||
121
src/tae.hpp
121
src/tae.hpp
@ -442,10 +442,12 @@ protected:
|
||||
bool decode_only;
|
||||
SDVersion version;
|
||||
|
||||
public:
|
||||
int z_channels = 16;
|
||||
|
||||
public:
|
||||
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), version(version) {
|
||||
int z_channels = 16;
|
||||
int patch = 1;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
z_channels = 48;
|
||||
@ -494,10 +496,12 @@ protected:
|
||||
bool decode_only;
|
||||
bool taef2 = false;
|
||||
|
||||
public:
|
||||
int z_channels = 4;
|
||||
|
||||
public:
|
||||
TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only) {
|
||||
int z_channels = 4;
|
||||
bool use_midblock_gn = false;
|
||||
taef2 = sd_version_is_flux2(version);
|
||||
|
||||
@ -533,20 +537,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
struct TinyAutoEncoder : public GGMLRunner {
|
||||
TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {}
|
||||
virtual bool compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
struct ggml_context* output_ctx = nullptr) = 0;
|
||||
|
||||
virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
|
||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||
};
|
||||
|
||||
struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||
struct TinyImageAutoEncoder : public VAE {
|
||||
TAESD taesd;
|
||||
bool decode_only = false;
|
||||
|
||||
@ -558,7 +549,8 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decoder_only),
|
||||
taesd(decoder_only, version),
|
||||
TinyAutoEncoder(backend, offload_params_to_cpu) {
|
||||
VAE(version, backend, offload_params_to_cpu) {
|
||||
scale_input = false;
|
||||
taesd.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
@ -566,37 +558,26 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||
return "taesd";
|
||||
}
|
||||
|
||||
bool load_from_file(const std::string& file_path, int n_threads) {
|
||||
LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
|
||||
alloc_params_buffer();
|
||||
std::map<std::string, ggml_tensor*> taesd_tensors;
|
||||
taesd.get_param_tensors(taesd_tensors);
|
||||
std::set<std::string> ignore_tensors;
|
||||
if (decode_only) {
|
||||
ignore_tensors.insert("encoder.");
|
||||
}
|
||||
|
||||
ModelLoader model_loader;
|
||||
if (!model_loader.init_from_file_and_convert_name(file_path)) {
|
||||
LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors, n_threads);
|
||||
|
||||
if (!success) {
|
||||
LOG_ERROR("load tae tensors from model loader failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG_INFO("taesd model loaded");
|
||||
return success;
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
taesd.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
||||
return vae_output;
|
||||
}
|
||||
|
||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
||||
}
|
||||
|
||||
int get_encoder_output_channels(int input_channels) {
|
||||
return taesd.z_channels;
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
z = to_backend(z);
|
||||
@ -606,7 +587,7 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
bool _compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
@ -619,7 +600,7 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
|
||||
}
|
||||
};
|
||||
|
||||
struct TinyVideoAutoEncoder : public TinyAutoEncoder {
|
||||
struct TinyVideoAutoEncoder : public VAE {
|
||||
TAEHV taehv;
|
||||
bool decode_only = false;
|
||||
|
||||
@ -631,7 +612,8 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decoder_only),
|
||||
taehv(decoder_only, version),
|
||||
TinyAutoEncoder(backend, offload_params_to_cpu) {
|
||||
VAE(version, backend, offload_params_to_cpu) {
|
||||
scale_input = false;
|
||||
taehv.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
@ -639,37 +621,26 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
|
||||
return "taehv";
|
||||
}
|
||||
|
||||
bool load_from_file(const std::string& file_path, int n_threads) {
|
||||
LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false");
|
||||
alloc_params_buffer();
|
||||
std::map<std::string, ggml_tensor*> taehv_tensors;
|
||||
taehv.get_param_tensors(taehv_tensors);
|
||||
std::set<std::string> ignore_tensors;
|
||||
if (decode_only) {
|
||||
ignore_tensors.insert("encoder.");
|
||||
}
|
||||
|
||||
ModelLoader model_loader;
|
||||
if (!model_loader.init_from_file(file_path)) {
|
||||
LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads);
|
||||
|
||||
if (!success) {
|
||||
LOG_ERROR("load tae tensors from model loader failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
LOG_INFO("taehv model loaded");
|
||||
return success;
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
|
||||
taehv.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
||||
return vae_output;
|
||||
}
|
||||
|
||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
||||
}
|
||||
|
||||
int get_encoder_output_channels(int input_channels) {
|
||||
return taehv.z_channels;
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
z = to_backend(z);
|
||||
@ -679,7 +650,7 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
bool _compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
|
||||
935
src/vae.hpp
935
src/vae.hpp
@ -3,631 +3,202 @@
|
||||
|
||||
#include "common_block.hpp"
|
||||
|
||||
/*================================================== AutoEncoderKL ===================================================*/
|
||||
|
||||
#define VAE_GRAPH_SIZE 20480
|
||||
|
||||
class ResnetBlock : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
int64_t out_channels;
|
||||
|
||||
public:
|
||||
ResnetBlock(int64_t in_channels,
|
||||
int64_t out_channels)
|
||||
: in_channels(in_channels),
|
||||
out_channels(out_channels) {
|
||||
// temb_channels is always 0
|
||||
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
|
||||
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
if (out_channels != in_channels) {
|
||||
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w]
|
||||
// t_emb is always None
|
||||
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
|
||||
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
|
||||
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
|
||||
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
|
||||
|
||||
auto h = x;
|
||||
h = norm1->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||
h = conv1->forward(ctx, h);
|
||||
// return h;
|
||||
|
||||
h = norm2->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish
|
||||
// dropout, skip for inference
|
||||
h = conv2->forward(ctx, h);
|
||||
|
||||
// skip connection
|
||||
if (out_channels != in_channels) {
|
||||
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
|
||||
|
||||
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
|
||||
}
|
||||
|
||||
h = ggml_add(ctx->ggml_ctx, h, x);
|
||||
return h; // [N, out_channels, h, w]
|
||||
}
|
||||
};
|
||||
|
||||
class AttnBlock : public UnaryBlock {
|
||||
protected:
|
||||
int64_t in_channels;
|
||||
bool use_linear;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
|
||||
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
|
||||
if (iter != tensor_storage_map.end()) {
|
||||
if (iter->second.n_dims == 4 && use_linear) {
|
||||
use_linear = false;
|
||||
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
|
||||
} else if (iter->second.n_dims == 2 && !use_linear) {
|
||||
use_linear = true;
|
||||
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
AttnBlock(int64_t in_channels, bool use_linear)
|
||||
: in_channels(in_channels), use_linear(use_linear) {
|
||||
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
|
||||
if (use_linear) {
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, in_channels));
|
||||
} else {
|
||||
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w]
|
||||
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
|
||||
auto q_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["q"]);
|
||||
auto k_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["k"]);
|
||||
auto v_proj = std::dynamic_pointer_cast<UnaryBlock>(blocks["v"]);
|
||||
auto proj_out = std::dynamic_pointer_cast<UnaryBlock>(blocks["proj_out"]);
|
||||
|
||||
auto h_ = norm->forward(ctx, x);
|
||||
|
||||
const int64_t n = h_->ne[3];
|
||||
const int64_t c = h_->ne[2];
|
||||
const int64_t h = h_->ne[1];
|
||||
const int64_t w = h_->ne[0];
|
||||
|
||||
ggml_tensor* q;
|
||||
ggml_tensor* k;
|
||||
ggml_tensor* v;
|
||||
if (use_linear) {
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
} else {
|
||||
q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
k = k_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
|
||||
|
||||
v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
|
||||
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
|
||||
}
|
||||
|
||||
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled);
|
||||
|
||||
if (use_linear) {
|
||||
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
|
||||
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
} else {
|
||||
h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w]
|
||||
h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w]
|
||||
|
||||
h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w]
|
||||
}
|
||||
|
||||
h_ = ggml_add(ctx->ggml_ctx, h_, x);
|
||||
return h_;
|
||||
}
|
||||
};
|
||||
|
||||
class AE3DConv : public Conv2d {
|
||||
public:
|
||||
AE3DConv(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
std::pair<int, int> kernel_size,
|
||||
int video_kernel_size = 3,
|
||||
std::pair<int, int> stride = {1, 1},
|
||||
std::pair<int, int> padding = {0, 0},
|
||||
std::pair<int, int> dilation = {1, 1},
|
||||
bool bias = true)
|
||||
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
|
||||
int kernel_padding = video_kernel_size / 2;
|
||||
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(out_channels,
|
||||
out_channels,
|
||||
{video_kernel_size, 1, 1},
|
||||
{1, 1, 1},
|
||||
{kernel_padding, 0, 0}));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
|
||||
struct ggml_tensor* x) override {
|
||||
// timesteps always None
|
||||
// skip_video always False
|
||||
// x: [N, IC, IH, IW]
|
||||
// result: [N, OC, OH, OW]
|
||||
auto time_mix_conv = std::dynamic_pointer_cast<Conv3d>(blocks["time_mix_conv"]);
|
||||
|
||||
x = Conv2d::forward(ctx, x);
|
||||
// timesteps = x.shape[0]
|
||||
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
// x = conv3d(x)
|
||||
// return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
int64_t T = x->ne[3];
|
||||
int64_t B = x->ne[3] / T;
|
||||
int64_t C = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||
x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW]
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||
return x; // [B*T, OC, OH, OW]
|
||||
}
|
||||
};
|
||||
|
||||
class VideoResnetBlock : public ResnetBlock {
|
||||
protected:
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
|
||||
enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32);
|
||||
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
|
||||
}
|
||||
|
||||
float get_alpha() {
|
||||
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
|
||||
return sigmoid(alpha);
|
||||
}
|
||||
|
||||
public:
|
||||
VideoResnetBlock(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
int video_kernel_size = 3)
|
||||
: ResnetBlock(in_channels, out_channels) {
|
||||
// merge_strategy is always learned
|
||||
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [N, in_channels, h, w] aka [b*t, in_channels, h, w]
|
||||
// return: [N, out_channels, h, w] aka [b*t, out_channels, h, w]
|
||||
// t_emb is always None
|
||||
// skip_video is always False
|
||||
// timesteps is always None
|
||||
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
|
||||
|
||||
x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w]
|
||||
// return x;
|
||||
|
||||
int64_t T = x->ne[3];
|
||||
int64_t B = x->ne[3] / T;
|
||||
int64_t C = x->ne[2];
|
||||
int64_t H = x->ne[1];
|
||||
int64_t W = x->ne[0];
|
||||
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
|
||||
auto x_mix = x;
|
||||
|
||||
x = time_stack->forward(ctx, x); // b t c (h w)
|
||||
|
||||
float alpha = get_alpha();
|
||||
x = ggml_add(ctx->ggml_ctx,
|
||||
ggml_ext_scale(ctx->ggml_ctx, x, alpha),
|
||||
ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha));
|
||||
|
||||
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
|
||||
x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
|
||||
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.modules.diffusionmodules.model.Encoder
|
||||
class Encoder : public GGMLBlock {
|
||||
protected:
|
||||
int ch = 128;
|
||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
int in_channels = 3;
|
||||
int z_channels = 4;
|
||||
bool double_z = true;
|
||||
|
||||
public:
|
||||
Encoder(int ch,
|
||||
std::vector<int> ch_mult,
|
||||
int num_res_blocks,
|
||||
int in_channels,
|
||||
int z_channels,
|
||||
bool double_z = true,
|
||||
bool use_linear_projection = false)
|
||||
: ch(ch),
|
||||
ch_mult(ch_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
in_channels(in_channels),
|
||||
z_channels(z_channels),
|
||||
double_z(double_z) {
|
||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
size_t num_resolutions = ch_mult.size();
|
||||
|
||||
int block_in = 1;
|
||||
for (int i = 0; i < num_resolutions; i++) {
|
||||
if (i == 0) {
|
||||
block_in = ch;
|
||||
} else {
|
||||
block_in = ch * ch_mult[i - 1];
|
||||
}
|
||||
int block_out = ch * ch_mult[i];
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_out));
|
||||
block_in = block_out;
|
||||
}
|
||||
if (i != num_resolutions - 1) {
|
||||
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new DownSampleBlock(block_in, block_in, true));
|
||||
}
|
||||
}
|
||||
|
||||
blocks["mid.block_1"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
|
||||
|
||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
// x: [N, in_channels, h, w]
|
||||
|
||||
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
||||
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
||||
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
||||
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
||||
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||
|
||||
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
|
||||
|
||||
// downsampling
|
||||
size_t num_resolutions = ch_mult.size();
|
||||
for (int i = 0; i < num_resolutions; i++) {
|
||||
for (int j = 0; j < num_res_blocks; j++) {
|
||||
std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
auto down_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||
|
||||
h = down_block->forward(ctx, h);
|
||||
}
|
||||
if (i != num_resolutions - 1) {
|
||||
std::string name = "down." + std::to_string(i) + ".downsample";
|
||||
auto down_sample = std::dynamic_pointer_cast<DownSampleBlock>(blocks[name]);
|
||||
|
||||
h = down_sample->forward(ctx, h);
|
||||
}
|
||||
}
|
||||
|
||||
// middle
|
||||
h = mid_block_1->forward(ctx, h);
|
||||
h = mid_attn_1->forward(ctx, h);
|
||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||
|
||||
// end
|
||||
h = norm_out->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||
h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w]
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.modules.diffusionmodules.model.Decoder
|
||||
class Decoder : public GGMLBlock {
|
||||
protected:
|
||||
int ch = 128;
|
||||
int out_ch = 3;
|
||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
int z_channels = 4;
|
||||
bool video_decoder = false;
|
||||
int video_kernel_size = 3;
|
||||
|
||||
virtual std::shared_ptr<GGMLBlock> get_conv_out(int64_t in_channels,
|
||||
int64_t out_channels,
|
||||
std::pair<int, int> kernel_size,
|
||||
std::pair<int, int> stride = {1, 1},
|
||||
std::pair<int, int> padding = {0, 0}) {
|
||||
if (video_decoder) {
|
||||
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
|
||||
} else {
|
||||
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
|
||||
}
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<GGMLBlock> get_resnet_block(int64_t in_channels,
|
||||
int64_t out_channels) {
|
||||
if (video_decoder) {
|
||||
return std::shared_ptr<GGMLBlock>(new VideoResnetBlock(in_channels, out_channels, video_kernel_size));
|
||||
} else {
|
||||
return std::shared_ptr<GGMLBlock>(new ResnetBlock(in_channels, out_channels));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Decoder(int ch,
|
||||
int out_ch,
|
||||
std::vector<int> ch_mult,
|
||||
int num_res_blocks,
|
||||
int z_channels,
|
||||
bool use_linear_projection = false,
|
||||
bool video_decoder = false,
|
||||
int video_kernel_size = 3)
|
||||
: ch(ch),
|
||||
out_ch(out_ch),
|
||||
ch_mult(ch_mult),
|
||||
num_res_blocks(num_res_blocks),
|
||||
z_channels(z_channels),
|
||||
video_decoder(video_decoder),
|
||||
video_kernel_size(video_kernel_size) {
|
||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||
int block_in = ch * ch_mult[num_resolutions - 1];
|
||||
|
||||
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
|
||||
|
||||
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
|
||||
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in, use_linear_projection));
|
||||
blocks["mid.block_2"] = get_resnet_block(block_in, block_in);
|
||||
|
||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||
int mult = ch_mult[i];
|
||||
int block_out = ch * mult;
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
blocks[name] = get_resnet_block(block_in, block_out);
|
||||
|
||||
block_in = block_out;
|
||||
}
|
||||
if (i != 0) {
|
||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||
blocks[name] = std::shared_ptr<GGMLBlock>(new UpSampleBlock(block_in, block_in));
|
||||
}
|
||||
}
|
||||
|
||||
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
|
||||
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
|
||||
}
|
||||
|
||||
virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||
// z: [N, z_channels, h, w]
|
||||
// alpha is always 0
|
||||
// merge_strategy is always learned
|
||||
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
|
||||
// AttnVideoBlock will not be used
|
||||
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
|
||||
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
|
||||
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
|
||||
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
|
||||
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
|
||||
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
|
||||
|
||||
// conv_in
|
||||
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
|
||||
|
||||
// middle
|
||||
h = mid_block_1->forward(ctx, h);
|
||||
// return h;
|
||||
|
||||
h = mid_attn_1->forward(ctx, h);
|
||||
h = mid_block_2->forward(ctx, h); // [N, block_in, h, w]
|
||||
|
||||
// upsampling
|
||||
int num_resolutions = static_cast<int>(ch_mult.size());
|
||||
for (int i = num_resolutions - 1; i >= 0; i--) {
|
||||
for (int j = 0; j < num_res_blocks + 1; j++) {
|
||||
std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j);
|
||||
auto up_block = std::dynamic_pointer_cast<ResnetBlock>(blocks[name]);
|
||||
|
||||
h = up_block->forward(ctx, h);
|
||||
}
|
||||
if (i != 0) {
|
||||
std::string name = "up." + std::to_string(i) + ".upsample";
|
||||
auto up_sample = std::dynamic_pointer_cast<UpSampleBlock>(blocks[name]);
|
||||
|
||||
h = up_sample->forward(ctx, h);
|
||||
}
|
||||
}
|
||||
|
||||
h = norm_out->forward(ctx, h);
|
||||
h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish
|
||||
h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8]
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.models.autoencoder.AutoencoderKL
|
||||
class AutoencodingEngine : public GGMLBlock {
|
||||
struct VAE : public GGMLRunner {
|
||||
protected:
|
||||
SDVersion version;
|
||||
bool decode_only = true;
|
||||
bool use_video_decoder = false;
|
||||
bool use_quant = true;
|
||||
int embed_dim = 4;
|
||||
struct {
|
||||
int z_channels = 4;
|
||||
int resolution = 256;
|
||||
int in_channels = 3;
|
||||
int out_ch = 3;
|
||||
int ch = 128;
|
||||
std::vector<int> ch_mult = {1, 2, 4, 4};
|
||||
int num_res_blocks = 2;
|
||||
bool double_z = true;
|
||||
} dd_config;
|
||||
|
||||
public:
|
||||
AutoencodingEngine(SDVersion version = VERSION_SD1,
|
||||
bool decode_only = true,
|
||||
bool use_linear_projection = false,
|
||||
bool use_video_decoder = false)
|
||||
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
|
||||
if (sd_version_is_dit(version)) {
|
||||
if (sd_version_is_flux2(version)) {
|
||||
dd_config.z_channels = 32;
|
||||
embed_dim = 32;
|
||||
} else {
|
||||
use_quant = false;
|
||||
dd_config.z_channels = 16;
|
||||
}
|
||||
}
|
||||
if (use_video_decoder) {
|
||||
use_quant = false;
|
||||
}
|
||||
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
|
||||
dd_config.out_ch,
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.z_channels,
|
||||
use_linear_projection,
|
||||
use_video_decoder));
|
||||
if (use_quant) {
|
||||
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
|
||||
embed_dim,
|
||||
{1, 1}));
|
||||
}
|
||||
if (!decode_only) {
|
||||
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
|
||||
dd_config.ch_mult,
|
||||
dd_config.num_res_blocks,
|
||||
dd_config.in_channels,
|
||||
dd_config.z_channels,
|
||||
dd_config.double_z,
|
||||
use_linear_projection));
|
||||
if (use_quant) {
|
||||
int factor = dd_config.double_z ? 2 : 1;
|
||||
|
||||
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
|
||||
dd_config.z_channels * factor,
|
||||
{1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
|
||||
// z: [N, z_channels, h, w]
|
||||
if (sd_version_is_flux2(version)) {
|
||||
// [N, C*p*p, h, w] -> [N, C, h*p, w*p]
|
||||
int64_t p = 2;
|
||||
|
||||
int64_t N = z->ne[3];
|
||||
int64_t C = z->ne[2] / p / p;
|
||||
int64_t h = z->ne[1];
|
||||
int64_t w = z->ne[0];
|
||||
int64_t H = h * p;
|
||||
int64_t W = w * p;
|
||||
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p]
|
||||
}
|
||||
|
||||
if (use_quant) {
|
||||
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
|
||||
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
|
||||
}
|
||||
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
|
||||
|
||||
ggml_set_name(z, "bench-start");
|
||||
auto h = decoder->forward(ctx, z);
|
||||
ggml_set_name(h, "bench-end");
|
||||
return h;
|
||||
}
|
||||
|
||||
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
|
||||
// x: [N, in_channels, h, w]
|
||||
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
|
||||
|
||||
auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
|
||||
if (use_quant) {
|
||||
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
|
||||
z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8]
|
||||
}
|
||||
if (sd_version_is_flux2(version)) {
|
||||
z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0];
|
||||
|
||||
// [N, C, H, W] -> [N, C*p*p, H/p, W/p]
|
||||
int64_t p = 2;
|
||||
int64_t N = z->ne[3];
|
||||
int64_t C = z->ne[2];
|
||||
int64_t H = z->ne[1];
|
||||
int64_t W = z->ne[0];
|
||||
int64_t h = H / p;
|
||||
int64_t w = W / p;
|
||||
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p]
|
||||
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w]
|
||||
z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w]
|
||||
}
|
||||
return z;
|
||||
}
|
||||
};
|
||||
|
||||
struct VAE : public GGMLRunner {
|
||||
VAE(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {}
|
||||
virtual bool compute(const int n_threads,
|
||||
bool scale_input = true;
|
||||
virtual bool _compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
struct ggml_context* output_ctx) = 0;
|
||||
|
||||
public:
|
||||
VAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
|
||||
: version(version), GGMLRunner(backend, offload_params_to_cpu) {}
|
||||
|
||||
int get_scale_factor() {
|
||||
int scale_factor = 8;
|
||||
if (version == VERSION_WAN2_2_TI2V) {
|
||||
scale_factor = 16;
|
||||
} else if (sd_version_is_flux2(version)) {
|
||||
scale_factor = 16;
|
||||
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||
scale_factor = 1;
|
||||
}
|
||||
return scale_factor;
|
||||
}
|
||||
|
||||
virtual int get_encoder_output_channels(int input_channels) = 0;
|
||||
|
||||
void get_tile_sizes(int& tile_size_x,
|
||||
int& tile_size_y,
|
||||
float& tile_overlap,
|
||||
const sd_tiling_params_t& params,
|
||||
int64_t latent_x,
|
||||
int64_t latent_y,
|
||||
float encoding_factor = 1.0f) {
|
||||
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
||||
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
|
||||
const int default_tile_size = 32;
|
||||
const int min_tile_dimension = 4;
|
||||
int tile_size = default_tile_size;
|
||||
// factor <= 1 means simple fraction of the latent dimension
|
||||
// factor > 1 means number of tiles across that dimension
|
||||
if (factor > 0.f) {
|
||||
if (factor > 1.0)
|
||||
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
||||
tile_size = static_cast<int>(std::round(latent_size * factor));
|
||||
} else if (requested_size >= min_tile_dimension) {
|
||||
tile_size = requested_size;
|
||||
}
|
||||
tile_size = static_cast<int>(tile_size * encoding_factor);
|
||||
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
|
||||
};
|
||||
|
||||
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
||||
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
||||
}
|
||||
|
||||
ggml_tensor* encode(int n_threads,
|
||||
ggml_context* work_ctx,
|
||||
ggml_tensor* x,
|
||||
sd_tiling_params_t tiling_params,
|
||||
bool circular_x = false,
|
||||
bool circular_y = false) {
|
||||
int64_t t0 = ggml_time_ms();
|
||||
ggml_tensor* result = nullptr;
|
||||
const int scale_factor = get_scale_factor();
|
||||
int64_t W = x->ne[0] / scale_factor;
|
||||
int64_t H = x->ne[1] / scale_factor;
|
||||
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
|
||||
int64_t C = get_encoder_output_channels(static_cast<int>(x->ne[channel_dim]));
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
if (sd_version_is_wan(version)) {
|
||||
int64_t T = x->ne[2];
|
||||
ne2 = (T - 1) / 4 + 1;
|
||||
ne3 = C;
|
||||
} else {
|
||||
ne2 = C;
|
||||
ne3 = x->ne[3];
|
||||
}
|
||||
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
|
||||
|
||||
if (scale_input) {
|
||||
scale_to_minus1_1(x);
|
||||
}
|
||||
|
||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||
}
|
||||
|
||||
if (tiling_params.enabled) {
|
||||
float tile_overlap;
|
||||
int tile_size_x, tile_size_y;
|
||||
// multiply tile size for encode to keep the compute buffer size consistent
|
||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, W, H, 1.30539f);
|
||||
|
||||
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
||||
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return _compute(n_threads, in, false, &out, work_ctx);
|
||||
};
|
||||
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
|
||||
} else {
|
||||
_compute(n_threads, x, false, &result, work_ctx);
|
||||
}
|
||||
free_compute_buffer();
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_tensor* decode(int n_threads,
|
||||
ggml_context* work_ctx,
|
||||
ggml_tensor* x,
|
||||
sd_tiling_params_t tiling_params,
|
||||
bool decode_video = false,
|
||||
bool circular_x = false,
|
||||
bool circular_y = false,
|
||||
ggml_tensor* result = nullptr,
|
||||
bool silent = false) {
|
||||
const int scale_factor = get_scale_factor();
|
||||
int64_t W = x->ne[0] * scale_factor;
|
||||
int64_t H = x->ne[1] * scale_factor;
|
||||
int64_t C = 3;
|
||||
if (result == nullptr) {
|
||||
if (decode_video) {
|
||||
int64_t T = x->ne[2];
|
||||
if (sd_version_is_wan(version)) {
|
||||
T = ((T - 1) * 4) + 1;
|
||||
}
|
||||
result = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
W,
|
||||
H,
|
||||
T,
|
||||
3);
|
||||
} else {
|
||||
result = ggml_new_tensor_4d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
W,
|
||||
H,
|
||||
C,
|
||||
x->ne[3]);
|
||||
}
|
||||
}
|
||||
int64_t t0 = ggml_time_ms();
|
||||
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
||||
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||
}
|
||||
if (tiling_params.enabled) {
|
||||
float tile_overlap;
|
||||
int tile_size_x, tile_size_y;
|
||||
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, x->ne[0], x->ne[1]);
|
||||
|
||||
if (!silent) {
|
||||
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
||||
}
|
||||
|
||||
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
||||
return _compute(n_threads, in, true, &out, nullptr);
|
||||
};
|
||||
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling, silent);
|
||||
} else {
|
||||
if (!_compute(n_threads, x, true, &result, work_ctx)) {
|
||||
LOG_ERROR("Failed to decode latetnts");
|
||||
free_compute_buffer();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
free_compute_buffer();
|
||||
if (scale_input) {
|
||||
scale_to_0_1(result);
|
||||
}
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
|
||||
return result;
|
||||
}
|
||||
|
||||
virtual ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) = 0;
|
||||
virtual ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
|
||||
virtual ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
|
||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
|
||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||
};
|
||||
|
||||
struct FakeVAE : public VAE {
|
||||
FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||
: VAE(backend, offload_params_to_cpu) {}
|
||||
bool compute(const int n_threads,
|
||||
FakeVAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
|
||||
: VAE(version, backend, offload_params_to_cpu) {}
|
||||
|
||||
int get_encoder_output_channels(int input_channels) {
|
||||
return input_channels;
|
||||
}
|
||||
|
||||
bool _compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
@ -642,6 +213,18 @@ struct FakeVAE : public VAE {
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
||||
return vae_output;
|
||||
}
|
||||
|
||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -649,126 +232,4 @@ struct FakeVAE : public VAE {
|
||||
}
|
||||
};
|
||||
|
||||
struct AutoEncoderKL : public VAE {
|
||||
bool decode_only = true;
|
||||
AutoencodingEngine ae;
|
||||
|
||||
AutoEncoderKL(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2TensorStorage& tensor_storage_map,
|
||||
const std::string prefix,
|
||||
bool decode_only = false,
|
||||
bool use_video_decoder = false,
|
||||
SDVersion version = VERSION_SD1)
|
||||
: decode_only(decode_only), VAE(backend, offload_params_to_cpu) {
|
||||
bool use_linear_projection = false;
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (!starts_with(name, prefix)) {
|
||||
continue;
|
||||
}
|
||||
if (ends_with(name, "attn_1.proj_out.weight")) {
|
||||
if (tensor_storage.n_dims == 2) {
|
||||
use_linear_projection = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder);
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
void set_conv2d_scale(float scale) override {
|
||||
std::vector<GGMLBlock*> blocks;
|
||||
ae.get_all_blocks(blocks);
|
||||
for (auto block : blocks) {
|
||||
if (block->get_desc() == "Conv2d") {
|
||||
auto conv_block = (Conv2d*)block;
|
||||
conv_block->set_scale(scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
return "vae";
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {
|
||||
ae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
|
||||
|
||||
z = to_backend(z);
|
||||
|
||||
auto runner_ctx = get_context();
|
||||
|
||||
struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z);
|
||||
|
||||
ggml_build_forward_expand(gf, out);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
struct ggml_context* output_ctx = nullptr) override {
|
||||
GGML_ASSERT(!decode_only || decode_graph);
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(z, decode_graph);
|
||||
};
|
||||
// ggml_set_f32(z, 0.5f);
|
||||
// print_ggml_tensor(z);
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
}
|
||||
|
||||
void test() {
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = false;
|
||||
|
||||
struct ggml_context* work_ctx = ggml_init(params);
|
||||
GGML_ASSERT(work_ctx != nullptr);
|
||||
|
||||
{
|
||||
// CPU, x{1, 3, 64, 64}: Pass
|
||||
// CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan
|
||||
// CPU, x{2, 3, 64, 64}: Wrong result
|
||||
// CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result
|
||||
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2);
|
||||
ggml_set_f32(x, 0.5f);
|
||||
print_ggml_tensor(x);
|
||||
struct ggml_tensor* out = nullptr;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
compute(8, x, false, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
LOG_DEBUG("encode test done in %lldms", t1 - t0);
|
||||
}
|
||||
|
||||
if (false) {
|
||||
// CPU, z{1, 4, 8, 8}: Pass
|
||||
// CUDA, z{1, 4, 8, 8}: Pass
|
||||
// CPU, z{3, 4, 8, 8}: Wrong result
|
||||
// CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result
|
||||
auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1);
|
||||
ggml_set_f32(z, 0.5f);
|
||||
print_ggml_tensor(z);
|
||||
struct ggml_tensor* out = nullptr;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
compute(8, z, true, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
LOG_DEBUG("decode test done in %lldms", t1 - t0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // __VAE_HPP__
|
||||
|
||||
102
src/wan.hpp
102
src/wan.hpp
@ -1109,6 +1109,7 @@ namespace WAN {
|
||||
};
|
||||
|
||||
struct WanVAERunner : public VAE {
|
||||
float scale_factor = 1.0f;
|
||||
bool decode_only = true;
|
||||
WanVAE ae;
|
||||
|
||||
@ -1118,7 +1119,7 @@ namespace WAN {
|
||||
const std::string prefix = "",
|
||||
bool decode_only = false,
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) {
|
||||
: decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(version, backend, offload_params_to_cpu) {
|
||||
ae.init(params_ctx, tensor_storage_map, prefix);
|
||||
}
|
||||
|
||||
@ -1130,6 +1131,101 @@ namespace WAN {
|
||||
ae.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
||||
return vae_output;
|
||||
}
|
||||
|
||||
void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector<float>& latents_mean_vec, std::vector<float>& latents_std_vec) {
|
||||
GGML_ASSERT(latents->ne[channel_dim] == 16 || latents->ne[channel_dim] == 48);
|
||||
if (latents->ne[channel_dim] == 16) { // Wan2.1 VAE
|
||||
latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f,
|
||||
3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f};
|
||||
} else if (latents->ne[channel_dim] == 48) { // Wan2.2 VAE
|
||||
latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f,
|
||||
-0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f,
|
||||
-0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f,
|
||||
-0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f,
|
||||
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
|
||||
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
|
||||
latents_std_vec = {
|
||||
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
|
||||
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
|
||||
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
|
||||
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
|
||||
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
|
||||
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
ggml_tensor* vae_latents = ggml_dup(work_ctx, latents);
|
||||
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[j];
|
||||
std_ = latents_std_vec[j];
|
||||
}
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
||||
value = value * std_ / scale_factor + mean;
|
||||
ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vae_latents;
|
||||
}
|
||||
|
||||
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
||||
ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents);
|
||||
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
|
||||
std::vector<float> latents_mean_vec;
|
||||
std::vector<float> latents_std_vec;
|
||||
get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec);
|
||||
|
||||
float mean;
|
||||
float std_;
|
||||
for (int i = 0; i < latents->ne[3]; i++) {
|
||||
if (channel_dim == 3) {
|
||||
mean = latents_mean_vec[i];
|
||||
std_ = latents_std_vec[i];
|
||||
}
|
||||
for (int j = 0; j < latents->ne[2]; j++) {
|
||||
if (channel_dim == 2) {
|
||||
mean = latents_mean_vec[j];
|
||||
std_ = latents_std_vec[j];
|
||||
}
|
||||
for (int k = 0; k < latents->ne[1]; k++) {
|
||||
for (int l = 0; l < latents->ne[0]; l++) {
|
||||
float value = ggml_ext_tensor_get_f32(latents, l, k, j, i);
|
||||
value = (value - mean) * scale_factor / std_;
|
||||
ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return diffusion_latents;
|
||||
}
|
||||
|
||||
int get_encoder_output_channels(int input_channels) {
|
||||
return static_cast<int>(ae.z_dim);
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
|
||||
struct ggml_cgraph* gf = new_graph_custom(10240 * z->ne[2]);
|
||||
|
||||
@ -1173,7 +1269,7 @@ namespace WAN {
|
||||
return gf;
|
||||
}
|
||||
|
||||
bool compute(const int n_threads,
|
||||
bool _compute(const int n_threads,
|
||||
struct ggml_tensor* z,
|
||||
bool decode_graph,
|
||||
struct ggml_tensor** output,
|
||||
@ -1249,7 +1345,7 @@ namespace WAN {
|
||||
struct ggml_tensor* out = nullptr;
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
compute(8, z, true, &out, work_ctx);
|
||||
_compute(8, z, true, &out, work_ctx);
|
||||
int64_t t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user