From 9371620737fba53cab9873a6ff66a9c12db89964 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 15 Mar 2026 16:03:31 +0800 Subject: [PATCH] refactor: optimize the VAE architecture --- src/auto_encoder_kl.hpp | 930 ++++++++++++++++++++++++++++++++++++++ src/ggml_extend.hpp | 35 +- src/model.cpp | 6 +- src/name_conversion.cpp | 6 +- src/stable-diffusion.cpp | 637 +++++---------------------- src/tae.hpp | 139 +++--- src/vae.hpp | 931 +++++++++------------------------------ src/wan.hpp | 112 ++++- 8 files changed, 1437 insertions(+), 1359 deletions(-) create mode 100644 src/auto_encoder_kl.hpp diff --git a/src/auto_encoder_kl.hpp b/src/auto_encoder_kl.hpp new file mode 100644 index 0000000..581bc59 --- /dev/null +++ b/src/auto_encoder_kl.hpp @@ -0,0 +1,930 @@ +#ifndef __AUTO_ENCODER_KL_HPP__ +#define __AUTO_ENCODER_KL_HPP__ + +#include "vae.hpp" + +/*================================================== AutoEncoderKL ===================================================*/ + +#define VAE_GRAPH_SIZE 20480 + +class ResnetBlock : public UnaryBlock { +protected: + int64_t in_channels; + int64_t out_channels; + +public: + ResnetBlock(int64_t in_channels, + int64_t out_channels) + : in_channels(in_channels), + out_channels(out_channels) { + // temb_channels is always 0 + blocks["norm1"] = std::shared_ptr(new GroupNorm32(in_channels)); + blocks["conv1"] = std::shared_ptr(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); + + blocks["norm2"] = std::shared_ptr(new GroupNorm32(out_channels)); + blocks["conv2"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); + + if (out_channels != in_channels) { + blocks["nin_shortcut"] = std::shared_ptr(new Conv2d(in_channels, out_channels, {1, 1})); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + // x: [N, in_channels, h, w] + // t_emb is always None + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + + auto h = x; + h = norm1->forward(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish + h = conv1->forward(ctx, h); + // return h; + + h = norm2->forward(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish + // dropout, skip for inference + h = conv2->forward(ctx, h); + + // skip connection + if (out_channels != in_channels) { + auto nin_shortcut = std::dynamic_pointer_cast(blocks["nin_shortcut"]); + + x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w] + } + + h = ggml_add(ctx->ggml_ctx, h, x); + return h; // [N, out_channels, h, w] + } +}; + +class AttnBlock : public UnaryBlock { +protected: + int64_t in_channels; + bool use_linear; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") { + auto iter = tensor_storage_map.find(prefix + "proj_out.weight"); + if (iter != tensor_storage_map.end()) { + if (iter->second.n_dims == 4 && use_linear) { + use_linear = false; + blocks["q"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + blocks["k"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + blocks["v"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + blocks["proj_out"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + } else if (iter->second.n_dims == 2 && !use_linear) { + use_linear = true; + blocks["q"] = std::make_shared(in_channels, in_channels); + blocks["k"] = std::make_shared(in_channels, in_channels); + blocks["v"] = std::make_shared(in_channels, in_channels); + blocks["proj_out"] = std::make_shared(in_channels, in_channels); + } + } + } + +public: + AttnBlock(int64_t in_channels, bool use_linear) + : in_channels(in_channels), use_linear(use_linear) { + blocks["norm"] = std::shared_ptr(new GroupNorm32(in_channels)); + if (use_linear) { + blocks["q"] = std::shared_ptr(new Linear(in_channels, in_channels)); + blocks["k"] = std::shared_ptr(new Linear(in_channels, in_channels)); + blocks["v"] = std::shared_ptr(new Linear(in_channels, in_channels)); + blocks["proj_out"] = std::shared_ptr(new Linear(in_channels, in_channels)); + } else { + blocks["q"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); + blocks["k"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); + blocks["v"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); + blocks["proj_out"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + // x: [N, in_channels, h, w] + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto q_proj = std::dynamic_pointer_cast(blocks["q"]); + auto k_proj = std::dynamic_pointer_cast(blocks["k"]); + auto v_proj = std::dynamic_pointer_cast(blocks["v"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + + auto h_ = norm->forward(ctx, x); + + const int64_t n = h_->ne[3]; + const int64_t c = h_->ne[2]; + const int64_t h = h_->ne[1]; + const int64_t w = h_->ne[0]; + + ggml_tensor* q; + ggml_tensor* k; + ggml_tensor* v; + if (use_linear) { + h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels] + h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels] + + q = q_proj->forward(ctx, h_); // [N, h * w, in_channels] + k = k_proj->forward(ctx, h_); // [N, h * w, in_channels] + v = v_proj->forward(ctx, h_); // [N, h * w, in_channels] + } else { + q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] + q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] + q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels] + + k = k_proj->forward(ctx, h_); // [N, in_channels, h, w] + k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] + k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels] + + v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] + v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels] + v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels] + } + + h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled); + + if (use_linear) { + h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] + + h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] + h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w] + } else { + h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] + h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w] + + h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w] + } + + h_ = ggml_add(ctx->ggml_ctx, h_, x); + return h_; + } +}; + +class AE3DConv : public Conv2d { +public: + AE3DConv(int64_t in_channels, + int64_t out_channels, + std::pair kernel_size, + int video_kernel_size = 3, + std::pair stride = {1, 1}, + std::pair padding = {0, 0}, + std::pair dilation = {1, 1}, + bool bias = true) + : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) { + int kernel_padding = video_kernel_size / 2; + blocks["time_mix_conv"] = std::shared_ptr(new Conv3d(out_channels, + out_channels, + {video_kernel_size, 1, 1}, + {1, 1, 1}, + {kernel_padding, 0, 0})); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x) override { + // timesteps always None + // skip_video always False + // x: [N, IC, IH, IW] + // result: [N, OC, OH, OW] + auto time_mix_conv = std::dynamic_pointer_cast(blocks["time_mix_conv"]); + + x = Conv2d::forward(ctx, x); + // timesteps = x.shape[0] + // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + // x = conv3d(x) + // return rearrange(x, "b c t h w -> (b t) c h w") + int64_t T = x->ne[3]; + int64_t B = x->ne[3] / T; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) + x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW] + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w + return x; // [B*T, OC, OH, OW] + } +}; + +class VideoResnetBlock : public ResnetBlock { +protected: + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32); + params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); + } + + float get_alpha() { + float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]); + return sigmoid(alpha); + } + +public: + VideoResnetBlock(int64_t in_channels, + int64_t out_channels, + int video_kernel_size = 3) + : ResnetBlock(in_channels, out_channels) { + // merge_strategy is always learned + blocks["time_stack"] = std::shared_ptr(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + // x: [N, in_channels, h, w] aka [b*t, in_channels, h, w] + // return: [N, out_channels, h, w] aka [b*t, out_channels, h, w] + // t_emb is always None + // skip_video is always False + // timesteps is always None + auto time_stack = std::dynamic_pointer_cast(blocks["time_stack"]); + + x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w] + // return x; + + int64_t T = x->ne[3]; + int64_t B = x->ne[3] / T; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) + auto x_mix = x; + + x = time_stack->forward(ctx, x); // b t c (h w) + + float alpha = get_alpha(); + x = ggml_add(ctx->ggml_ctx, + ggml_ext_scale(ctx->ggml_ctx, x, alpha), + ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); + + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w + + return x; + } +}; + +// ldm.modules.diffusionmodules.model.Encoder +class Encoder : public GGMLBlock { +protected: + int ch = 128; + std::vector ch_mult = {1, 2, 4, 4}; + int num_res_blocks = 2; + int in_channels = 3; + int z_channels = 4; + bool double_z = true; + +public: + Encoder(int ch, + std::vector ch_mult, + int num_res_blocks, + int in_channels, + int z_channels, + bool double_z = true, + bool use_linear_projection = false) + : ch(ch), + ch_mult(ch_mult), + num_res_blocks(num_res_blocks), + in_channels(in_channels), + z_channels(z_channels), + double_z(double_z) { + blocks["conv_in"] = std::shared_ptr(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1})); + + size_t num_resolutions = ch_mult.size(); + + int block_in = 1; + for (int i = 0; i < num_resolutions; i++) { + if (i == 0) { + block_in = ch; + } else { + block_in = ch * ch_mult[i - 1]; + } + int block_out = ch * ch_mult[i]; + for (int j = 0; j < num_res_blocks; j++) { + std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j); + blocks[name] = std::shared_ptr(new ResnetBlock(block_in, block_out)); + block_in = block_out; + } + if (i != num_resolutions - 1) { + std::string name = "down." + std::to_string(i) + ".downsample"; + blocks[name] = std::shared_ptr(new DownSampleBlock(block_in, block_in, true)); + } + } + + blocks["mid.block_1"] = std::shared_ptr(new ResnetBlock(block_in, block_in)); + blocks["mid.attn_1"] = std::shared_ptr(new AttnBlock(block_in, use_linear_projection)); + blocks["mid.block_2"] = std::shared_ptr(new ResnetBlock(block_in, block_in)); + + blocks["norm_out"] = std::shared_ptr(new GroupNorm32(block_in)); + blocks["conv_out"] = std::shared_ptr(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1})); + } + + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + // x: [N, in_channels, h, w] + + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto mid_block_1 = std::dynamic_pointer_cast(blocks["mid.block_1"]); + auto mid_attn_1 = std::dynamic_pointer_cast(blocks["mid.attn_1"]); + auto mid_block_2 = std::dynamic_pointer_cast(blocks["mid.block_2"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + + auto h = conv_in->forward(ctx, x); // [N, ch, h, w] + + // downsampling + size_t num_resolutions = ch_mult.size(); + for (int i = 0; i < num_resolutions; i++) { + for (int j = 0; j < num_res_blocks; j++) { + std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j); + auto down_block = std::dynamic_pointer_cast(blocks[name]); + + h = down_block->forward(ctx, h); + } + if (i != num_resolutions - 1) { + std::string name = "down." + std::to_string(i) + ".downsample"; + auto down_sample = std::dynamic_pointer_cast(blocks[name]); + + h = down_sample->forward(ctx, h); + } + } + + // middle + h = mid_block_1->forward(ctx, h); + h = mid_attn_1->forward(ctx, h); + h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] + + // end + h = norm_out->forward(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish + h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w] + return h; + } +}; + +// ldm.modules.diffusionmodules.model.Decoder +class Decoder : public GGMLBlock { +protected: + int ch = 128; + int out_ch = 3; + std::vector ch_mult = {1, 2, 4, 4}; + int num_res_blocks = 2; + int z_channels = 4; + bool video_decoder = false; + int video_kernel_size = 3; + + virtual std::shared_ptr get_conv_out(int64_t in_channels, + int64_t out_channels, + std::pair kernel_size, + std::pair stride = {1, 1}, + std::pair padding = {0, 0}) { + if (video_decoder) { + return std::shared_ptr(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding)); + } else { + return std::shared_ptr(new Conv2d(in_channels, out_channels, kernel_size, stride, padding)); + } + } + + virtual std::shared_ptr get_resnet_block(int64_t in_channels, + int64_t out_channels) { + if (video_decoder) { + return std::shared_ptr(new VideoResnetBlock(in_channels, out_channels, video_kernel_size)); + } else { + return std::shared_ptr(new ResnetBlock(in_channels, out_channels)); + } + } + +public: + Decoder(int ch, + int out_ch, + std::vector ch_mult, + int num_res_blocks, + int z_channels, + bool use_linear_projection = false, + bool video_decoder = false, + int video_kernel_size = 3) + : ch(ch), + out_ch(out_ch), + ch_mult(ch_mult), + num_res_blocks(num_res_blocks), + z_channels(z_channels), + video_decoder(video_decoder), + video_kernel_size(video_kernel_size) { + int num_resolutions = static_cast(ch_mult.size()); + int block_in = ch * ch_mult[num_resolutions - 1]; + + blocks["conv_in"] = std::shared_ptr(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1})); + + blocks["mid.block_1"] = get_resnet_block(block_in, block_in); + blocks["mid.attn_1"] = std::shared_ptr(new AttnBlock(block_in, use_linear_projection)); + blocks["mid.block_2"] = get_resnet_block(block_in, block_in); + + for (int i = num_resolutions - 1; i >= 0; i--) { + int mult = ch_mult[i]; + int block_out = ch * mult; + for (int j = 0; j < num_res_blocks + 1; j++) { + std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); + blocks[name] = get_resnet_block(block_in, block_out); + + block_in = block_out; + } + if (i != 0) { + std::string name = "up." + std::to_string(i) + ".upsample"; + blocks[name] = std::shared_ptr(new UpSampleBlock(block_in, block_in)); + } + } + + blocks["norm_out"] = std::shared_ptr(new GroupNorm32(block_in)); + blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1}); + } + + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) { + // z: [N, z_channels, h, w] + // alpha is always 0 + // merge_strategy is always learned + // time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock + // AttnVideoBlock will not be used + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto mid_block_1 = std::dynamic_pointer_cast(blocks["mid.block_1"]); + auto mid_attn_1 = std::dynamic_pointer_cast(blocks["mid.attn_1"]); + auto mid_block_2 = std::dynamic_pointer_cast(blocks["mid.block_2"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + + // conv_in + auto h = conv_in->forward(ctx, z); // [N, block_in, h, w] + + // middle + h = mid_block_1->forward(ctx, h); + // return h; + + h = mid_attn_1->forward(ctx, h); + h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] + + // upsampling + int num_resolutions = static_cast(ch_mult.size()); + for (int i = num_resolutions - 1; i >= 0; i--) { + for (int j = 0; j < num_res_blocks + 1; j++) { + std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); + auto up_block = std::dynamic_pointer_cast(blocks[name]); + + h = up_block->forward(ctx, h); + } + if (i != 0) { + std::string name = "up." + std::to_string(i) + ".upsample"; + auto up_sample = std::dynamic_pointer_cast(blocks[name]); + + h = up_sample->forward(ctx, h); + } + } + + h = norm_out->forward(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish + h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8] + return h; + } +}; + +// ldm.models.autoencoder.AutoencoderKL +class AutoEncoderKLModel : public GGMLBlock { +protected: + SDVersion version; + bool decode_only = true; + bool use_video_decoder = false; + bool use_quant = true; + int embed_dim = 4; + struct { + int z_channels = 4; + int resolution = 256; + int in_channels = 3; + int out_ch = 3; + int ch = 128; + std::vector ch_mult = {1, 2, 4, 4}; + int num_res_blocks = 2; + bool double_z = true; + } dd_config; + +public: + AutoEncoderKLModel(SDVersion version = VERSION_SD1, + bool decode_only = true, + bool use_linear_projection = false, + bool use_video_decoder = false) + : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) { + if (sd_version_is_dit(version)) { + if (sd_version_is_flux2(version)) { + dd_config.z_channels = 32; + embed_dim = 32; + } else { + use_quant = false; + dd_config.z_channels = 16; + } + } + if (use_video_decoder) { + use_quant = false; + } + blocks["decoder"] = std::shared_ptr(new Decoder(dd_config.ch, + dd_config.out_ch, + dd_config.ch_mult, + dd_config.num_res_blocks, + dd_config.z_channels, + use_linear_projection, + use_video_decoder)); + if (use_quant) { + blocks["post_quant_conv"] = std::shared_ptr(new Conv2d(dd_config.z_channels, + embed_dim, + {1, 1})); + } + if (!decode_only) { + blocks["encoder"] = std::shared_ptr(new Encoder(dd_config.ch, + dd_config.ch_mult, + dd_config.num_res_blocks, + dd_config.in_channels, + dd_config.z_channels, + dd_config.double_z, + use_linear_projection)); + if (use_quant) { + int factor = dd_config.double_z ? 2 : 1; + + blocks["quant_conv"] = std::shared_ptr(new Conv2d(embed_dim * factor, + dd_config.z_channels * factor, + {1, 1})); + } + } + } + + struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { + // z: [N, z_channels, h, w] + if (sd_version_is_flux2(version)) { + // [N, C*p*p, h, w] -> [N, C, h*p, w*p] + int64_t p = 2; + + int64_t N = z->ne[3]; + int64_t C = z->ne[2] / p / p; + int64_t h = z->ne[1]; + int64_t w = z->ne[0]; + int64_t H = h * p; + int64_t W = w * p; + + z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w] + z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p] + z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p] + z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p] + z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p] + } + + if (use_quant) { + auto post_quant_conv = std::dynamic_pointer_cast(blocks["post_quant_conv"]); + z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w] + } + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + + ggml_set_name(z, "bench-start"); + auto h = decoder->forward(ctx, z); + ggml_set_name(h, "bench-end"); + return h; + } + + struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + // x: [N, in_channels, h, w] + auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); + + auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8] + if (use_quant) { + auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); + z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] + } + if (sd_version_is_flux2(version)) { + z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; + + // [N, C, H, W] -> [N, C*p*p, H/p, W/p] + int64_t p = 2; + int64_t N = z->ne[3]; + int64_t C = z->ne[2]; + int64_t H = z->ne[1]; + int64_t W = z->ne[0]; + int64_t h = H / p; + int64_t w = W / p; + + z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p] + z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p] + z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p] + z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w] + z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w] + } + return z; + } + + int get_encoder_output_channels() { + int factor = dd_config.double_z ? 2 : 1; + return dd_config.z_channels * factor; + } +}; + +struct AutoEncoderKL : public VAE { + float scale_factor = 1.f; + float shift_factor = 0.f; + bool decode_only = true; + AutoEncoderKLModel ae; + + AutoEncoderKL(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix, + bool decode_only = false, + bool use_video_decoder = false, + SDVersion version = VERSION_SD1) + : decode_only(decode_only), VAE(version, backend, offload_params_to_cpu) { + if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { + scale_factor = 0.18215f; + shift_factor = 0.f; + } else if (sd_version_is_sdxl(version)) { + scale_factor = 0.13025f; + shift_factor = 0.f; + } else if (sd_version_is_sd3(version)) { + scale_factor = 1.5305f; + shift_factor = 0.0609f; + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + scale_factor = 0.3611f; + shift_factor = 0.1159f; + } else if (sd_version_is_flux2(version)) { + scale_factor = 1.0f; + shift_factor = 0.f; + } + bool use_linear_projection = false; + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (ends_with(name, "attn_1.proj_out.weight")) { + if (tensor_storage.n_dims == 2) { + use_linear_projection = true; + } + break; + } + } + ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder); + ae.init(params_ctx, tensor_storage_map, prefix); + } + + void set_conv2d_scale(float scale) override { + std::vector blocks; + ae.get_all_blocks(blocks); + for (auto block : blocks) { + if (block->get_desc() == "Conv2d") { + auto conv_block = (Conv2d*)block; + conv_block->set_scale(scale); + } + } + } + + std::string get_desc() override { + return "vae"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) override { + ae.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { + struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); + + z = to_backend(z); + + auto runner_ctx = get_context(); + + struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + bool _compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) override { + GGML_ASSERT(!decode_only || decode_graph); + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(z, decode_graph); + }; + // ggml_set_f32(z, 0.5f); + // print_ggml_tensor(z); + return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments, std::shared_ptr rng) { + // ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample + ggml_tensor* latents = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]); + struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latents); + ggml_ext_im_set_randn_f32(noise, rng); + { + float mean = 0; + float logvar = 0; + float value = 0; + float std_ = 0; + for (int i = 0; i < latents->ne[3]; i++) { + for (int j = 0; j < latents->ne[2]; j++) { + for (int k = 0; k < latents->ne[1]; k++) { + for (int l = 0; l < latents->ne[0]; l++) { + mean = ggml_ext_tensor_get_f32(moments, l, k, j, i); + logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latents->ne[2], i); + logvar = std::max(-30.0f, std::min(logvar, 20.0f)); + std_ = std::exp(0.5f * logvar); + value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i); + // printf("%d %d %d %d -> %f\n", i, j, k, l, value); + ggml_ext_tensor_set_f32(latents, value, l, k, j, i); + } + } + } + } + } + return latents; + } + + ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr rng) { + if (sd_version_is_flux2(version)) { + return vae_output; + } else if (version == VERSION_SD1_PIX2PIX) { + return ggml_view_3d(work_ctx, + vae_output, + vae_output->ne[0], + vae_output->ne[1], + vae_output->ne[2] / 2, + vae_output->nb[1], + vae_output->nb[2], + 0); + } else { + return gaussian_latent_sample(work_ctx, vae_output, rng); + } + } + + void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector& latents_mean_vec, std::vector& latents_std_vec) { + // flux2 + if (sd_version_is_flux2(version)) { + GGML_ASSERT(latents->ne[channel_dim] == 128); + latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f, + -0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f, + -0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f, + 0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f, + -0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f, + 0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f, + 0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f, + 0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f, + -0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f, + 0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f, + -0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f, + 0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f, + -0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f, + 0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f, + -0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f, + -0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f}; + latents_std_vec = { + 1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f, + 1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f, + 1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f, + 1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f, + 1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f, + 1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f, + 1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f, + 1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f, + 1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f, + 1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f, + 1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f, + 1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f, + 1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f, + 1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f, + 1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f, + 1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f}; + } else { + GGML_ABORT("unknown version %d", version); + } + } + + ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) { + ggml_tensor* vae_latents = ggml_dup(work_ctx, latents); + if (sd_version_is_flux2(version)) { + int channel_dim = 2; + std::vector latents_mean_vec; + std::vector latents_std_vec; + get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec); + + float mean; + float std_; + for (int i = 0; i < latents->ne[3]; i++) { + if (channel_dim == 3) { + mean = latents_mean_vec[i]; + std_ = latents_std_vec[i]; + } + for (int j = 0; j < latents->ne[2]; j++) { + if (channel_dim == 2) { + mean = latents_mean_vec[j]; + std_ = latents_std_vec[j]; + } + for (int k = 0; k < latents->ne[1]; k++) { + for (int l = 0; l < latents->ne[0]; l++) { + float value = ggml_ext_tensor_get_f32(latents, l, k, j, i); + value = value * std_ / scale_factor + mean; + ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i); + } + } + } + } + } else { + ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3); + value = (value / scale_factor) + shift_factor; + ggml_ext_tensor_set_f32(vae_latents, value, i0, i1, i2, i3); + }); + } + return vae_latents; + } + + ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) { + ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents); + if (sd_version_is_flux2(version)) { + int channel_dim = 2; + std::vector latents_mean_vec; + std::vector latents_std_vec; + get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec); + + float mean; + float std_; + for (int i = 0; i < latents->ne[3]; i++) { + if (channel_dim == 3) { + mean = latents_mean_vec[i]; + std_ = latents_std_vec[i]; + } + for (int j = 0; j < latents->ne[2]; j++) { + if (channel_dim == 2) { + mean = latents_mean_vec[j]; + std_ = latents_std_vec[j]; + } + for (int k = 0; k < latents->ne[1]; k++) { + for (int l = 0; l < latents->ne[0]; l++) { + float value = ggml_ext_tensor_get_f32(latents, l, k, j, i); + value = (value - mean) * scale_factor / std_; + ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i); + } + } + } + } + } else { + ggml_ext_tensor_iter(latents, [&](ggml_tensor* latents, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = ggml_ext_tensor_get_f32(latents, i0, i1, i2, i3); + value = (value - shift_factor) * scale_factor; + ggml_ext_tensor_set_f32(diffusion_latents, value, i0, i1, i2, i3); + }); + } + return diffusion_latents; + } + + int get_encoder_output_channels(int input_channels) { + return ae.get_encoder_output_channels(); + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB + params.mem_buffer = nullptr; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != nullptr); + + { + // CPU, x{1, 3, 64, 64}: Pass + // CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan + // CPU, x{2, 3, 64, 64}: Wrong result + // CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result + auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2); + ggml_set_f32(x, 0.5f); + print_ggml_tensor(x); + struct ggml_tensor* out = nullptr; + + int64_t t0 = ggml_time_ms(); + _compute(8, x, false, &out, work_ctx); + int64_t t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("encode test done in %lldms", t1 - t0); + } + + if (false) { + // CPU, z{1, 4, 8, 8}: Pass + // CUDA, z{1, 4, 8, 8}: Pass + // CPU, z{3, 4, 8, 8}: Wrong result + // CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result + auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); + ggml_set_f32(z, 0.5f); + print_ggml_tensor(z); + struct ggml_tensor* out = nullptr; + + int64_t t0 = ggml_time_ms(); + _compute(8, z, true, &out, work_ctx); + int64_t t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("decode test done in %lldms", t1 - t0); + } + }; +}; + +#endif // __AUTO_ENCODER_KL_HPP__ \ No newline at end of file diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 954aee2..a51976e 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -377,6 +377,12 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten ggml_free(ctx); } +__STATIC_INLINE__ ggml_tensor* ggml_ext_dup_and_cpy_tensor(ggml_context* ctx, ggml_tensor* src) { + ggml_tensor* dup = ggml_dup_tensor(ctx, src); + copy_ggml_tensor(dup, src); + return dup; +} + __STATIC_INLINE__ float sigmoid(float x) { return 1 / (1.0f + expf(-x)); } @@ -637,7 +643,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_tensor_concat(struct ggml_context } // convert values from [0, 1] to [-1, 1] -__STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) { +__STATIC_INLINE__ void scale_to_minus1_1(struct ggml_tensor* src) { int64_t nelements = ggml_nelements(src); float* data = (float*)src->data; for (int i = 0; i < nelements; i++) { @@ -647,7 +653,7 @@ __STATIC_INLINE__ void process_vae_input_tensor(struct ggml_tensor* src) { } // convert values from [-1, 1] to [0, 1] -__STATIC_INLINE__ void process_vae_output_tensor(struct ggml_tensor* src) { +__STATIC_INLINE__ void scale_to_0_1(struct ggml_tensor* src) { int64_t nelements = ggml_nelements(src); float* data = (float*)src->data; for (int i = 0; i < nelements; i++) { @@ -834,7 +840,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, const float tile_overlap_factor, const bool circular_x, const bool circular_y, - on_tile_process on_processing) { + on_tile_process on_processing, + bool slient = false) { output = ggml_set_f32(output, 0); int input_width = (int)input->ne[0]; @@ -864,8 +871,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, float tile_overlap_factor_y; sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor, circular_y); - LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y); - LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor); + if (!slient) { + LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y); + LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor); + } int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x); int non_tile_overlap_x = p_tile_size_x - tile_overlap_x; @@ -896,7 +905,9 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, params.mem_buffer = nullptr; params.no_alloc = false; - LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + if (!slient) { + LOG_DEBUG("tile work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + } // draft context struct ggml_context* tiles_ctx = ggml_init(params); @@ -909,8 +920,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]); ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]); int num_tiles = num_tiles_x * num_tiles_y; - LOG_DEBUG("processing %i tiles", num_tiles); - pretty_progress(0, num_tiles, 0.0f); + if (!slient) { + LOG_DEBUG("processing %i tiles", num_tiles); + pretty_progress(0, num_tiles, 0.0f); + } int tile_count = 1; bool last_y = false, last_x = false; float last_time = 0.0f; @@ -960,8 +973,10 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, } last_x = false; } - if (tile_count < num_tiles) { - pretty_progress(num_tiles, num_tiles, last_time); + if (!slient) { + if (tile_count < num_tiles) { + pretty_progress(num_tiles, num_tiles, last_time); + } } ggml_free(tiles_ctx); } diff --git a/src/model.cpp b/src/model.cpp index 77b032c..87b6545 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -1104,10 +1104,12 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) { has_middle_block_1 = true; } - if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos || + tensor_storage.name.find("unet.up_blocks.1.attentions.0.transformer_blocks.1") != std::string::npos) { has_output_block_311 = true; } - if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos || + tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) { has_output_block_71 = true; } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index 3b3abfb..d5d5e05 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -1120,7 +1120,11 @@ std::string convert_tensor_name(std::string name, SDVersion version) { for (const auto& prefix : first_stage_model_prefix_vec) { if (starts_with(name, prefix)) { name = convert_first_stage_model_name(name.substr(prefix.size()), prefix); - name = prefix + name; + if (version == VERSION_SDXS) { + name = "tae." + name; + } else { + name = prefix + name; + } break; } } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 613ebb0..25bce01 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -7,6 +7,7 @@ #include "stable-diffusion.h" #include "util.h" +#include "auto_encoder_kl.hpp" #include "cache_dit.hpp" #include "conditioner.hpp" #include "control.hpp" @@ -90,14 +91,6 @@ void calculate_alphas_cumprod(float* alphas_cumprod, } } -void suppress_pp(int step, int steps, float time, void* data) { - (void)step; - (void)steps; - (void)time; - (void)data; - return; -} - /*=============================================== StableDiffusionGGML ================================================*/ class StableDiffusionGGML { @@ -118,8 +111,6 @@ public: std::shared_ptr rng = std::make_shared(); std::shared_ptr sampler_rng = nullptr; int n_threads = -1; - float scale_factor = 0.18215f; - float shift_factor = 0.f; float default_flow_shift = INFINITY; std::shared_ptr cond_stage_model; @@ -127,7 +118,7 @@ public: std::shared_ptr diffusion_model; std::shared_ptr high_noise_diffusion_model; std::shared_ptr first_stage_model; - std::shared_ptr tae_first_stage; + std::shared_ptr preview_vae; std::shared_ptr control_net; std::shared_ptr pmid_model; std::shared_ptr pmid_lora; @@ -138,7 +129,6 @@ public: bool apply_lora_immediately = false; std::string taesd_path; - bool use_tiny_autoencoder = false; sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; bool offload_params_to_cpu = false; bool use_pmid = false; @@ -239,10 +229,10 @@ public: n_threads = sd_ctx_params->n_threads; vae_decode_only = sd_ctx_params->vae_decode_only; free_params_immediately = sd_ctx_params->free_params_immediately; - taesd_path = SAFE_STR(sd_ctx_params->taesd_path); - use_tiny_autoencoder = taesd_path.size() > 0; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; + bool use_tae = false; + rng = get_rng(sd_ctx_params->rng_type); if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT && sd_ctx_params->sampler_rng_type != sd_ctx_params->rng_type) { sampler_rng = get_rng(sd_ctx_params->sampler_rng_type); @@ -332,6 +322,14 @@ public: } } + if (strlen(SAFE_STR(sd_ctx_params->taesd_path)) > 0) { + LOG_INFO("loading tae from '%s'", sd_ctx_params->taesd_path); + if (!model_loader.init_from_file(sd_ctx_params->taesd_path, "tae.")) { + LOG_WARN("loading tae from '%s' failed", sd_ctx_params->taesd_path); + } + use_tae = true; + } + model_loader.convert_tensors_name(); version = model_loader.get_sd_version(); @@ -400,22 +398,6 @@ public: apply_lora_immediately = false; } - if (sd_version_is_sdxl(version)) { - scale_factor = 0.13025f; - } else if (sd_version_is_sd3(version)) { - scale_factor = 1.5305f; - shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { - scale_factor = 0.3611f; - shift_factor = 0.1159f; - } else if (sd_version_is_wan(version) || - sd_version_is_qwen_image(version) || - sd_version_is_anima(version) || - sd_version_is_flux2(version)) { - scale_factor = 1.0f; - shift_factor = 0.f; - } - if (sd_version_is_control(version)) { // Might need vae encode for control cond vae_decode_only = false; @@ -424,6 +406,7 @@ public: bool tae_preview_only = sd_ctx_params->tae_preview_only; if (version == VERSION_SDXS) { tae_preview_only = false; + use_tae = true; } if (sd_ctx_params->circular_x || sd_ctx_params->circular_y) { @@ -610,31 +593,46 @@ public: vae_backend = backend; } - if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) { - if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { - first_stage_model = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "first_stage_model", - vae_decode_only, - version); - first_stage_model->alloc_params_buffer(); - first_stage_model->get_param_tensors(tensors, "first_stage_model"); - } else if (version == VERSION_CHROMA_RADIANCE) { - first_stage_model = std::make_shared(vae_backend, - offload_params_to_cpu); + auto create_tae = [&]() -> std::shared_ptr { + if (sd_version_is_wan(version) || + sd_version_is_qwen_image(version) || + sd_version_is_anima(version)) { + return std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "decoder", + vae_decode_only, + version); + } else { - first_stage_model = std::make_shared(vae_backend, + auto model = std::make_shared(vae_backend, offload_params_to_cpu, tensor_storage_map, - "first_stage_model", + "decoder.layers", vae_decode_only, - false, version); - if (sd_ctx_params->vae_conv_direct) { - LOG_INFO("Using Conv2d direct in the vae model"); - first_stage_model->set_conv2d_direct_enabled(true); - } + return model; + } + }; + + auto create_vae = [&]() -> std::shared_ptr { + if (sd_version_is_wan(version) || + sd_version_is_qwen_image(version) || + sd_version_is_anima(version)) { + return std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only, + version); + } else { + auto model = std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + vae_decode_only, + false, + version); if (sd_version_is_sdxl(version) && (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) { float vae_conv_2d_scale = 1.f / 32.f; @@ -642,35 +640,40 @@ public: "No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " "using Conv2D scale %.3f", vae_conv_2d_scale); - first_stage_model->set_conv2d_scale(vae_conv_2d_scale); + model->set_conv2d_scale(vae_conv_2d_scale); } - first_stage_model->alloc_params_buffer(); - first_stage_model->get_param_tensors(tensors, "first_stage_model"); + return model; + } + }; + + if (version == VERSION_CHROMA_RADIANCE) { + LOG_INFO("using FakeVAE"); + first_stage_model = std::make_shared(version, + vae_backend, + offload_params_to_cpu); + } else if (use_tae && !tae_preview_only) { + LOG_INFO("using TAE for encoding / decoding"); + first_stage_model = create_tae(); + first_stage_model->alloc_params_buffer(); + first_stage_model->get_param_tensors(tensors, "tae"); + } else { + LOG_INFO("using VAE for encoding / decoding"); + first_stage_model = create_vae(); + first_stage_model->alloc_params_buffer(); + first_stage_model->get_param_tensors(tensors, "first_stage_model"); + if (use_tae && tae_preview_only) { + LOG_INFO("using TAE for preview"); + preview_vae = create_tae(); + preview_vae->alloc_params_buffer(); + preview_vae->get_param_tensors(tensors, "tae"); } } - if (use_tiny_autoencoder || version == VERSION_SDXS) { - if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { - tae_first_stage = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "decoder", - vae_decode_only, - version); - } else { - tae_first_stage = std::make_shared(vae_backend, - offload_params_to_cpu, - tensor_storage_map, - "decoder.layers", - vae_decode_only, - version); - if (version == VERSION_SDXS) { - tae_first_stage->alloc_params_buffer(); - tae_first_stage->get_param_tensors(tensors, "first_stage_model"); - } - } - if (sd_ctx_params->vae_conv_direct) { - LOG_INFO("Using Conv2d direct in the tae model"); - tae_first_stage->set_conv2d_direct_enabled(true); + + if (sd_ctx_params->vae_conv_direct) { + LOG_INFO("Using Conv2d direct in the vae model"); + first_stage_model->set_conv2d_direct_enabled(true); + if (preview_vae) { + preview_vae->set_conv2d_direct_enabled(true); } } @@ -743,8 +746,8 @@ public: if (first_stage_model) { first_stage_model->set_flash_attention_enabled(true); } - if (tae_first_stage) { - tae_first_stage->set_flash_attention_enabled(true); + if (preview_vae) { + preview_vae->set_flash_attention_enabled(true); } } @@ -782,7 +785,7 @@ public: std::set ignore_tensors; tensors["alphas_cumprod"] = alphas_cumprod_tensor; - if (use_tiny_autoencoder) { + if (use_tae && !tae_preview_only) { ignore_tensors.insert("first_stage_model."); } if (use_pmid) { @@ -796,6 +799,7 @@ public: ignore_tensors.insert("first_stage_model.encoder"); ignore_tensors.insert("first_stage_model.conv1"); ignore_tensors.insert("first_stage_model.quant"); + ignore_tensors.insert("tae.encoder"); ignore_tensors.insert("text_encoders.llm.visual."); } if (version == VERSION_OVIS_IMAGE) { @@ -822,15 +826,9 @@ public: unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size(); } size_t vae_params_mem_size = 0; - if (!(use_tiny_autoencoder || version == VERSION_SDXS) || tae_preview_only) { - vae_params_mem_size = first_stage_model->get_params_buffer_size(); - } - if (use_tiny_autoencoder || version == VERSION_SDXS) { - if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) { - return false; - } - use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS - vae_params_mem_size = tae_first_stage->get_params_buffer_size(); + vae_params_mem_size = first_stage_model->get_params_buffer_size(); + if (preview_vae) { + vae_params_mem_size += preview_vae->get_params_buffer_size(); } size_t control_net_params_mem_size = 0; if (control_net) { @@ -983,7 +981,6 @@ public: } ggml_free(ctx); - use_tiny_autoencoder = use_tiny_autoencoder && !tae_preview_only; return true; } @@ -1422,8 +1419,7 @@ public: ggml_ext_tensor_scale_inplace(noise, augmentation_level); ggml_ext_tensor_add_inplace(init_img, noise); } - ggml_tensor* moments = vae_encode(work_ctx, init_img); - c_concat = get_first_stage_encoding(work_ctx, moments); + c_concat = encode_first_stage(work_ctx, init_img); } } @@ -1475,14 +1471,6 @@ public: } } - void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { - sd_progress_cb_t cb = sd_get_progress_callback(); - void* cbd = sd_get_progress_callback_data(); - sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr); - sd_tiling(input, output, scale, tile_size, tile_overlap_factor, circular_x, circular_y, on_processing); - sd_set_progress_callback(cb, cbd); - } - void preview_image(ggml_context* work_ctx, int step, struct ggml_tensor* latents, @@ -1575,37 +1563,14 @@ public: free(data); free(images); } else { - if (preview_mode == PREVIEW_VAE) { - process_latent_out(latents); - if (vae_tiling_params.enabled) { - // split latent in 32x32 tiles and compute in several steps - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - return first_stage_model->compute(n_threads, in, true, &out, nullptr); - }; - silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling); - + if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) { + if (preview_vae) { + latents = preview_vae->diffusion_to_vae_latents(work_ctx, latents); + result = preview_vae->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true); } else { - first_stage_model->compute(n_threads, latents, true, &result, work_ctx); + latents = first_stage_model->diffusion_to_vae_latents(work_ctx, latents); + result = first_stage_model->decode(n_threads, work_ctx, latents, vae_tiling_params, false, circular_x, circular_y, result, true); } - - first_stage_model->free_compute_buffer(); - process_vae_output_tensor(result); - process_latent_in(latents); - } else if (preview_mode == PREVIEW_TAE) { - if (tae_first_stage == nullptr) { - LOG_WARN("TAE not found for preview"); - return; - } - if (vae_tiling_params.enabled) { - // split latent in 64x64 tiles and compute in several steps - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - return tae_first_stage->compute(n_threads, in, true, &out, nullptr); - }; - silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling); - } else { - tae_first_stage->compute(n_threads, latents, true, &result, work_ctx); - } - tae_first_stage->free_compute_buffer(); } else { return; } @@ -1829,8 +1794,7 @@ public: } size_t steps = sigmas.size() - 1; - struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); - copy_ggml_tensor(x, init_latent); + struct ggml_tensor* x = ggml_ext_dup_and_cpy_tensor(work_ctx, init_latent); if (noise) { x = denoiser->noise_scaling(sigmas[0], noise, x); @@ -2351,15 +2315,7 @@ public: } int get_vae_scale_factor() { - int vae_scale_factor = 8; - if (version == VERSION_WAN2_2_TI2V) { - vae_scale_factor = 16; - } else if (sd_version_is_flux2(version)) { - vae_scale_factor = 16; - } else if (version == VERSION_CHROMA_RADIANCE) { - vae_scale_factor = 1; - } - return vae_scale_factor; + return first_stage_model->get_scale_factor(); } int get_diffusion_model_down_factor() { @@ -2414,383 +2370,28 @@ public: } else { init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); } - ggml_set_f32(init_latent, shift_factor); + ggml_set_f32(init_latent, 0.f); return init_latent; } - void get_latents_mean_std_vec(ggml_tensor* latent, int channel_dim, std::vector& latents_mean_vec, std::vector& latents_std_vec) { - GGML_ASSERT(latent->ne[channel_dim] == 16 || latent->ne[channel_dim] == 48 || latent->ne[channel_dim] == 128); - if (latent->ne[channel_dim] == 16) { - latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f, - 0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f}; - latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f, - 3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f}; - } else if (latent->ne[channel_dim] == 48) { - latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f, - -0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f, - -0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f, - -0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f, - -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f, - 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f}; - latents_std_vec = { - 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, - 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, - 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, - 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, - 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, - 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; - } else if (latent->ne[channel_dim] == 128) { - // flux2 - latents_mean_vec = {-0.0676f, -0.0715f, -0.0753f, -0.0745f, 0.0223f, 0.0180f, 0.0142f, 0.0184f, - -0.0001f, -0.0063f, -0.0002f, -0.0031f, -0.0272f, -0.0281f, -0.0276f, -0.0290f, - -0.0769f, -0.0672f, -0.0902f, -0.0892f, 0.0168f, 0.0152f, 0.0079f, 0.0086f, - 0.0083f, 0.0015f, 0.0003f, -0.0043f, -0.0439f, -0.0419f, -0.0438f, -0.0431f, - -0.0102f, -0.0132f, -0.0066f, -0.0048f, -0.0311f, -0.0306f, -0.0279f, -0.0180f, - 0.0030f, 0.0015f, 0.0126f, 0.0145f, 0.0347f, 0.0338f, 0.0337f, 0.0283f, - 0.0020f, 0.0047f, 0.0047f, 0.0050f, 0.0123f, 0.0081f, 0.0081f, 0.0146f, - 0.0681f, 0.0679f, 0.0767f, 0.0732f, -0.0462f, -0.0474f, -0.0392f, -0.0511f, - -0.0528f, -0.0477f, -0.0470f, -0.0517f, -0.0317f, -0.0316f, -0.0345f, -0.0283f, - 0.0510f, 0.0445f, 0.0578f, 0.0458f, -0.0412f, -0.0458f, -0.0487f, -0.0467f, - -0.0088f, -0.0106f, -0.0088f, -0.0046f, -0.0376f, -0.0432f, -0.0436f, -0.0499f, - 0.0118f, 0.0166f, 0.0203f, 0.0279f, 0.0113f, 0.0129f, 0.0016f, 0.0072f, - -0.0118f, -0.0018f, -0.0141f, -0.0054f, -0.0091f, -0.0138f, -0.0145f, -0.0187f, - 0.0323f, 0.0305f, 0.0259f, 0.0300f, 0.0540f, 0.0614f, 0.0495f, 0.0590f, - -0.0511f, -0.0603f, -0.0478f, -0.0524f, -0.0227f, -0.0274f, -0.0154f, -0.0255f, - -0.0572f, -0.0565f, -0.0518f, -0.0496f, 0.0116f, 0.0054f, 0.0163f, 0.0104f}; - latents_std_vec = { - 1.8029f, 1.7786f, 1.7868f, 1.7837f, 1.7717f, 1.7590f, 1.7610f, 1.7479f, - 1.7336f, 1.7373f, 1.7340f, 1.7343f, 1.8626f, 1.8527f, 1.8629f, 1.8589f, - 1.7593f, 1.7526f, 1.7556f, 1.7583f, 1.7363f, 1.7400f, 1.7355f, 1.7394f, - 1.7342f, 1.7246f, 1.7392f, 1.7304f, 1.7551f, 1.7513f, 1.7559f, 1.7488f, - 1.8449f, 1.8454f, 1.8550f, 1.8535f, 1.8240f, 1.7813f, 1.7854f, 1.7945f, - 1.8047f, 1.7876f, 1.7695f, 1.7676f, 1.7782f, 1.7667f, 1.7925f, 1.7848f, - 1.7579f, 1.7407f, 1.7483f, 1.7368f, 1.7961f, 1.7998f, 1.7920f, 1.7925f, - 1.7780f, 1.7747f, 1.7727f, 1.7749f, 1.7526f, 1.7447f, 1.7657f, 1.7495f, - 1.7775f, 1.7720f, 1.7813f, 1.7813f, 1.8162f, 1.8013f, 1.8023f, 1.8033f, - 1.7527f, 1.7331f, 1.7563f, 1.7482f, 1.7610f, 1.7507f, 1.7681f, 1.7613f, - 1.7665f, 1.7545f, 1.7828f, 1.7726f, 1.7896f, 1.7999f, 1.7864f, 1.7760f, - 1.7613f, 1.7625f, 1.7560f, 1.7577f, 1.7783f, 1.7671f, 1.7810f, 1.7799f, - 1.7201f, 1.7068f, 1.7265f, 1.7091f, 1.7793f, 1.7578f, 1.7502f, 1.7455f, - 1.7587f, 1.7500f, 1.7525f, 1.7362f, 1.7616f, 1.7572f, 1.7444f, 1.7430f, - 1.7509f, 1.7610f, 1.7634f, 1.7612f, 1.7254f, 1.7135f, 1.7321f, 1.7226f, - 1.7664f, 1.7624f, 1.7718f, 1.7664f, 1.7457f, 1.7441f, 1.7569f, 1.7530f}; - } - } - - void process_latent_in(ggml_tensor* latent) { - if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) { - int channel_dim = sd_version_is_flux2(version) ? 2 : 3; - std::vector latents_mean_vec; - std::vector latents_std_vec; - get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec); - - float mean; - float std_; - for (int i = 0; i < latent->ne[3]; i++) { - if (channel_dim == 3) { - mean = latents_mean_vec[i]; - std_ = latents_std_vec[i]; - } - for (int j = 0; j < latent->ne[2]; j++) { - if (channel_dim == 2) { - mean = latents_mean_vec[i]; - std_ = latents_std_vec[i]; - } - for (int k = 0; k < latent->ne[1]; k++) { - for (int l = 0; l < latent->ne[0]; l++) { - float value = ggml_ext_tensor_get_f32(latent, l, k, j, i); - value = (value - mean) * scale_factor / std_; - ggml_ext_tensor_set_f32(latent, value, l, k, j, i); - } - } - } - } - } else if (version == VERSION_CHROMA_RADIANCE) { - // pass - } else { - ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { - float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3); - value = (value - shift_factor) * scale_factor; - ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3); - }); - } - } - - void process_latent_out(ggml_tensor* latent) { - if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_flux2(version)) { - int channel_dim = sd_version_is_flux2(version) ? 2 : 3; - std::vector latents_mean_vec; - std::vector latents_std_vec; - get_latents_mean_std_vec(latent, channel_dim, latents_mean_vec, latents_std_vec); - - float mean; - float std_; - for (int i = 0; i < latent->ne[3]; i++) { - if (channel_dim == 3) { - mean = latents_mean_vec[i]; - std_ = latents_std_vec[i]; - } - for (int j = 0; j < latent->ne[2]; j++) { - if (channel_dim == 2) { - mean = latents_mean_vec[i]; - std_ = latents_std_vec[i]; - } - for (int k = 0; k < latent->ne[1]; k++) { - for (int l = 0; l < latent->ne[0]; l++) { - float value = ggml_ext_tensor_get_f32(latent, l, k, j, i); - value = value * std_ / scale_factor + mean; - ggml_ext_tensor_set_f32(latent, value, l, k, j, i); - } - } - } - } - } else if (version == VERSION_CHROMA_RADIANCE) { - // pass - } else { - ggml_ext_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { - float value = ggml_ext_tensor_get_f32(latent, i0, i1, i2, i3); - value = (value / scale_factor) + shift_factor; - ggml_ext_tensor_set_f32(latent, value, i0, i1, i2, i3); - }); - } - } - - void get_tile_sizes(int& tile_size_x, - int& tile_size_y, - float& tile_overlap, - const sd_tiling_params_t& params, - int64_t latent_x, - int64_t latent_y, - float encoding_factor = 1.0f) { - tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f); - auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) { - const int default_tile_size = 32; - const int min_tile_dimension = 4; - int tile_size = default_tile_size; - // factor <= 1 means simple fraction of the latent dimension - // factor > 1 means number of tiles across that dimension - if (factor > 0.f) { - if (factor > 1.0) - factor = 1 / (factor - factor * tile_overlap + tile_overlap); - tile_size = static_cast(std::round(latent_size * factor)); - } else if (requested_size >= min_tile_dimension) { - tile_size = requested_size; - } - tile_size = static_cast(tile_size * encoding_factor); - return std::max(std::min(tile_size, static_cast(latent_size)), min_tile_dimension); - }; - - tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x); - tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y); - } - - ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x) { - int64_t t0 = ggml_time_ms(); - ggml_tensor* result = nullptr; - const int vae_scale_factor = get_vae_scale_factor(); - int64_t W = x->ne[0] / vae_scale_factor; - int64_t H = x->ne[1] / vae_scale_factor; - int64_t C = get_latent_channel(); - if (vae_tiling_params.enabled) { - // TODO wan2.2 vae support? - int64_t ne2; - int64_t ne3; - if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { - ne2 = 1; - ne3 = C * x->ne[3]; - } else { - int64_t out_channels = C; - bool encode_outputs_mu = use_tiny_autoencoder || - sd_version_is_wan(version) || - sd_version_is_flux2(version) || - version == VERSION_CHROMA_RADIANCE; - if (!encode_outputs_mu) { - out_channels *= 2; - } - ne2 = out_channels; - ne3 = x->ne[3]; - } - result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3); - } - - if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { - x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); - } - - if (!use_tiny_autoencoder) { - process_vae_input_tensor(x); - if (vae_tiling_params.enabled) { - float tile_overlap; - int tile_size_x, tile_size_y; - // multiply tile size for encode to keep the compute buffer size consistent - get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f); - - LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); - - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - return first_stage_model->compute(n_threads, in, false, &out, work_ctx); - }; - sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling); - } else { - first_stage_model->compute(n_threads, x, false, &result, work_ctx); - } - first_stage_model->free_compute_buffer(); - } else { - if (vae_tiling_params.enabled) { - // split latent in 32x32 tiles and compute in several steps - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - return tae_first_stage->compute(n_threads, in, false, &out, nullptr); - }; - sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling); - } else { - tae_first_stage->compute(n_threads, x, false, &result, work_ctx); - } - tae_first_stage->free_compute_buffer(); - } - - int64_t t1 = ggml_time_ms(); - LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - return result; - } - - ggml_tensor* gaussian_latent_sample(ggml_context* work_ctx, ggml_tensor* moments) { - // ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample - ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]); - struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent); - ggml_ext_im_set_randn_f32(noise, rng); - { - float mean = 0; - float logvar = 0; - float value = 0; - float std_ = 0; - for (int i = 0; i < latent->ne[3]; i++) { - for (int j = 0; j < latent->ne[2]; j++) { - for (int k = 0; k < latent->ne[1]; k++) { - for (int l = 0; l < latent->ne[0]; l++) { - mean = ggml_ext_tensor_get_f32(moments, l, k, j, i); - logvar = ggml_ext_tensor_get_f32(moments, l, k, j + (int)latent->ne[2], i); - logvar = std::max(-30.0f, std::min(logvar, 20.0f)); - std_ = std::exp(0.5f * logvar); - value = mean + std_ * ggml_ext_tensor_get_f32(noise, l, k, j, i); - // printf("%d %d %d %d -> %f\n", i, j, k, l, value); - ggml_ext_tensor_set_f32(latent, value, l, k, j, i); - } - } - } - } - } - return latent; - } - - ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) { - ggml_tensor* latent; - if (use_tiny_autoencoder || - sd_version_is_qwen_image(version) || - sd_version_is_anima(version) || - sd_version_is_wan(version) || - sd_version_is_flux2(version) || - version == VERSION_CHROMA_RADIANCE) { - latent = vae_output; - } else if (version == VERSION_SD1_PIX2PIX) { - latent = ggml_view_3d(work_ctx, - vae_output, - vae_output->ne[0], - vae_output->ne[1], - vae_output->ne[2] / 2, - vae_output->nb[1], - vae_output->nb[2], - 0); - } else { - latent = gaussian_latent_sample(work_ctx, vae_output); - } - if (!use_tiny_autoencoder && version != VERSION_SD1_PIX2PIX) { - process_latent_in(latent); - } - if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { - latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1); - } - return latent; + ggml_tensor* encode_to_vae_latents(ggml_context* work_ctx, ggml_tensor* x) { + ggml_tensor* vae_output = first_stage_model->encode(n_threads, work_ctx, x, vae_tiling_params, circular_x, circular_y); + ggml_tensor* latents = first_stage_model->vae_output_to_latents(work_ctx, vae_output, rng); + return latents; } ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x) { - ggml_tensor* vae_output = vae_encode(work_ctx, x); - return get_first_stage_encoding(work_ctx, vae_output); + ggml_tensor* latents = encode_to_vae_latents(work_ctx, x); + if (version != VERSION_SD1_PIX2PIX) { + latents = first_stage_model->vae_to_diffuison_latents(work_ctx, latents); + } + return latents; } ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) { - const int vae_scale_factor = get_vae_scale_factor(); - int64_t W = x->ne[0] * vae_scale_factor; - int64_t H = x->ne[1] * vae_scale_factor; - int64_t C = 3; - ggml_tensor* result = nullptr; - if (decode_video) { - int64_t T = x->ne[2]; - if (sd_version_is_wan(version)) { - T = ((T - 1) * 4) + 1; - } - result = ggml_new_tensor_4d(work_ctx, - GGML_TYPE_F32, - W, - H, - T, - 3); - } else { - result = ggml_new_tensor_4d(work_ctx, - GGML_TYPE_F32, - W, - H, - C, - x->ne[3]); - } - int64_t t0 = ggml_time_ms(); - if (!use_tiny_autoencoder) { - if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { - x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); - } - process_latent_out(x); - // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); - if (vae_tiling_params.enabled) { - float tile_overlap; - int tile_size_x, tile_size_y; - get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]); - - LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); - - // split latent in 32x32 tiles and compute in several steps - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - return first_stage_model->compute(n_threads, in, true, &out, nullptr); - }; - sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling); - } else { - if (!first_stage_model->compute(n_threads, x, true, &result, work_ctx)) { - LOG_ERROR("Failed to decode latetnts"); - first_stage_model->free_compute_buffer(); - return nullptr; - } - } - first_stage_model->free_compute_buffer(); - process_vae_output_tensor(result); - } else { - if (vae_tiling_params.enabled) { - // split latent in 64x64 tiles and compute in several steps - auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { - return tae_first_stage->compute(n_threads, in, true, &out); - }; - sd_tiling(x, result, vae_scale_factor, 64, 0.5f, circular_x, circular_y, on_tiling); - } else { - if (!tae_first_stage->compute(n_threads, x, true, &result)) { - LOG_ERROR("Failed to decode latetnts"); - tae_first_stage->free_compute_buffer(); - return nullptr; - } - } - tae_first_stage->free_compute_buffer(); - } - - int64_t t1 = ggml_time_ms(); - LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f); - return result; + x = first_stage_model->diffusion_to_vae_latents(work_ctx, x); + x = first_stage_model->decode(n_threads, work_ctx, x, vae_tiling_params, decode_video, circular_x, circular_y); + return x; } void set_flow_shift(float flow_shift = INFINITY) { @@ -3560,7 +3161,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int64_t t4 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000); - if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) { + if (sd_ctx->sd->free_params_immediately) { sd_ctx->sd->first_stage_model->free_params_buffer(); } @@ -3609,15 +3210,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g if (sd_ctx->sd->first_stage_model) { sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y); } - if (sd_ctx->sd->tae_first_stage) { - sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y); + if (sd_ctx->sd->preview_vae) { + sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y); } } else { int tile_size_x, tile_size_y; float _overlap; int latent_size_x = width / sd_ctx->sd->get_vae_scale_factor(); int latent_size_y = height / sd_ctx->sd->get_vae_scale_factor(); - sd_ctx->sd->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y); + sd_ctx->sd->first_stage_model->get_tile_sizes(tile_size_x, tile_size_y, _overlap, sd_img_gen_params->vae_tiling_params, latent_size_x, latent_size_y); // force disable circular padding for vae if tiling is enabled unless latent is smaller than tile size // otherwise it will cause artifacts at the edges of the tiles @@ -3627,8 +3228,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g if (sd_ctx->sd->first_stage_model) { sd_ctx->sd->first_stage_model->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y); } - if (sd_ctx->sd->tae_first_stage) { - sd_ctx->sd->tae_first_stage->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y); + if (sd_ctx->sd->preview_vae) { + sd_ctx->sd->preview_vae->set_circular_axes(sd_ctx->sd->circular_x, sd_ctx->sd->circular_y); } // disable circular tiling if it's enabled for the VAE @@ -4105,14 +3706,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sd_image_to_ggml_tensor(sd_vid_gen_params->init_image, init_img); init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3); - auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16] + auto init_image_latent = sd_ctx->sd->encode_to_vae_latents(work_ctx, init_img); // [b*c, 1, h/16, w/16] init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true); denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1); ggml_set_f32(denoise_mask, 1.f); - if (!sd_ctx->sd->use_tiny_autoencoder) - sd_ctx->sd->process_latent_out(init_latent); + init_latent = sd_ctx->sd->first_stage_model->diffusion_to_vae_latents(work_ctx, init_latent); ggml_ext_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = ggml_ext_tensor_get_f32(t, i0, i1, i2, i3); @@ -4122,8 +3722,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } }); - if (!sd_ctx->sd->use_tiny_autoencoder) - sd_ctx->sd->process_latent_in(init_latent); + init_latent = sd_ctx->sd->first_stage_model->vae_to_diffuison_latents(work_ctx, init_latent); int64_t t2 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); @@ -4346,7 +3945,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true); int64_t t5 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000); - if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) { + if (sd_ctx->sd->free_params_immediately) { sd_ctx->sd->first_stage_model->free_params_buffer(); } diff --git a/src/tae.hpp b/src/tae.hpp index 8315257..60df7b2 100644 --- a/src/tae.hpp +++ b/src/tae.hpp @@ -442,11 +442,13 @@ protected: bool decode_only; SDVersion version; +public: + int z_channels = 16; + public: TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2) : decode_only(decode_only), version(version) { - int z_channels = 16; - int patch = 1; + int patch = 1; if (version == VERSION_WAN2_2_TI2V) { z_channels = 48; patch = 2; @@ -494,10 +496,12 @@ protected: bool decode_only; bool taef2 = false; +public: + int z_channels = 4; + public: TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { - int z_channels = 4; bool use_midblock_gn = false; taef2 = sd_version_is_flux2(version); @@ -533,20 +537,7 @@ public: } }; -struct TinyAutoEncoder : public GGMLRunner { - TinyAutoEncoder(ggml_backend_t backend, bool offload_params_to_cpu) - : GGMLRunner(backend, offload_params_to_cpu) {} - virtual bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx = nullptr) = 0; - - virtual bool load_from_file(const std::string& file_path, int n_threads) = 0; - virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0; -}; - -struct TinyImageAutoEncoder : public TinyAutoEncoder { +struct TinyImageAutoEncoder : public VAE { TAESD taesd; bool decode_only = false; @@ -558,7 +549,8 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder { SDVersion version = VERSION_SD1) : decode_only(decoder_only), taesd(decoder_only, version), - TinyAutoEncoder(backend, offload_params_to_cpu) { + VAE(version, backend, offload_params_to_cpu) { + scale_input = false; taesd.init(params_ctx, tensor_storage_map, prefix); } @@ -566,37 +558,26 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder { return "taesd"; } - bool load_from_file(const std::string& file_path, int n_threads) { - LOG_INFO("loading taesd from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false"); - alloc_params_buffer(); - std::map taesd_tensors; - taesd.get_param_tensors(taesd_tensors); - std::set ignore_tensors; - if (decode_only) { - ignore_tensors.insert("encoder."); - } - - ModelLoader model_loader; - if (!model_loader.init_from_file_and_convert_name(file_path)) { - LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str()); - return false; - } - - bool success = model_loader.load_tensors(taesd_tensors, ignore_tensors, n_threads); - - if (!success) { - LOG_ERROR("load tae tensors from model loader failed"); - return false; - } - - LOG_INFO("taesd model loaded"); - return success; - } - void get_param_tensors(std::map& tensors, const std::string prefix) { taesd.get_param_tensors(tensors, prefix); } + ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr rng) { + return vae_output; + } + + ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) { + return ggml_ext_dup_and_cpy_tensor(work_ctx, latents); + } + + ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) { + return ggml_ext_dup_and_cpy_tensor(work_ctx, latents); + } + + int get_encoder_output_channels(int input_channels) { + return taesd.z_channels; + } + struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); z = to_backend(z); @@ -606,11 +587,11 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder { return gf; } - bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx = nullptr) { + bool _compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); }; @@ -619,7 +600,7 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder { } }; -struct TinyVideoAutoEncoder : public TinyAutoEncoder { +struct TinyVideoAutoEncoder : public VAE { TAEHV taehv; bool decode_only = false; @@ -631,7 +612,8 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder { SDVersion version = VERSION_WAN2) : decode_only(decoder_only), taehv(decoder_only, version), - TinyAutoEncoder(backend, offload_params_to_cpu) { + VAE(version, backend, offload_params_to_cpu) { + scale_input = false; taehv.init(params_ctx, tensor_storage_map, prefix); } @@ -639,37 +621,26 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder { return "taehv"; } - bool load_from_file(const std::string& file_path, int n_threads) { - LOG_INFO("loading taehv from '%s', decode_only = %s", file_path.c_str(), decode_only ? "true" : "false"); - alloc_params_buffer(); - std::map taehv_tensors; - taehv.get_param_tensors(taehv_tensors); - std::set ignore_tensors; - if (decode_only) { - ignore_tensors.insert("encoder."); - } - - ModelLoader model_loader; - if (!model_loader.init_from_file(file_path)) { - LOG_ERROR("init taehv model loader from file failed: '%s'", file_path.c_str()); - return false; - } - - bool success = model_loader.load_tensors(taehv_tensors, ignore_tensors, n_threads); - - if (!success) { - LOG_ERROR("load tae tensors from model loader failed"); - return false; - } - - LOG_INFO("taehv model loaded"); - return success; - } - void get_param_tensors(std::map& tensors, const std::string prefix) { taehv.get_param_tensors(tensors, prefix); } + ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr rng) { + return vae_output; + } + + ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) { + return ggml_ext_dup_and_cpy_tensor(work_ctx, latents); + } + + ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) { + return ggml_ext_dup_and_cpy_tensor(work_ctx, latents); + } + + int get_encoder_output_channels(int input_channels) { + return taehv.z_channels; + } + struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); z = to_backend(z); @@ -679,11 +650,11 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder { return gf; } - bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx = nullptr) { + bool _compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); }; diff --git a/src/vae.hpp b/src/vae.hpp index 7ccba6e..ad83e01 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -3,635 +3,206 @@ #include "common_block.hpp" -/*================================================== AutoEncoderKL ===================================================*/ - -#define VAE_GRAPH_SIZE 20480 - -class ResnetBlock : public UnaryBlock { -protected: - int64_t in_channels; - int64_t out_channels; - -public: - ResnetBlock(int64_t in_channels, - int64_t out_channels) - : in_channels(in_channels), - out_channels(out_channels) { - // temb_channels is always 0 - blocks["norm1"] = std::shared_ptr(new GroupNorm32(in_channels)); - blocks["conv1"] = std::shared_ptr(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); - - blocks["norm2"] = std::shared_ptr(new GroupNorm32(out_channels)); - blocks["conv2"] = std::shared_ptr(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1})); - - if (out_channels != in_channels) { - blocks["nin_shortcut"] = std::shared_ptr(new Conv2d(in_channels, out_channels, {1, 1})); - } - } - - struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { - // x: [N, in_channels, h, w] - // t_emb is always None - auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); - auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); - auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); - auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); - - auto h = x; - h = norm1->forward(ctx, h); - h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish - h = conv1->forward(ctx, h); - // return h; - - h = norm2->forward(ctx, h); - h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish - // dropout, skip for inference - h = conv2->forward(ctx, h); - - // skip connection - if (out_channels != in_channels) { - auto nin_shortcut = std::dynamic_pointer_cast(blocks["nin_shortcut"]); - - x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w] - } - - h = ggml_add(ctx->ggml_ctx, h, x); - return h; // [N, out_channels, h, w] - } -}; - -class AttnBlock : public UnaryBlock { -protected: - int64_t in_channels; - bool use_linear; - - void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") { - auto iter = tensor_storage_map.find(prefix + "proj_out.weight"); - if (iter != tensor_storage_map.end()) { - if (iter->second.n_dims == 4 && use_linear) { - use_linear = false; - blocks["q"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); - blocks["k"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); - blocks["v"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); - blocks["proj_out"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); - } else if (iter->second.n_dims == 2 && !use_linear) { - use_linear = true; - blocks["q"] = std::make_shared(in_channels, in_channels); - blocks["k"] = std::make_shared(in_channels, in_channels); - blocks["v"] = std::make_shared(in_channels, in_channels); - blocks["proj_out"] = std::make_shared(in_channels, in_channels); - } - } - } - -public: - AttnBlock(int64_t in_channels, bool use_linear) - : in_channels(in_channels), use_linear(use_linear) { - blocks["norm"] = std::shared_ptr(new GroupNorm32(in_channels)); - if (use_linear) { - blocks["q"] = std::shared_ptr(new Linear(in_channels, in_channels)); - blocks["k"] = std::shared_ptr(new Linear(in_channels, in_channels)); - blocks["v"] = std::shared_ptr(new Linear(in_channels, in_channels)); - blocks["proj_out"] = std::shared_ptr(new Linear(in_channels, in_channels)); - } else { - blocks["q"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); - blocks["k"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); - blocks["v"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); - blocks["proj_out"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); - } - } - - struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { - // x: [N, in_channels, h, w] - auto norm = std::dynamic_pointer_cast(blocks["norm"]); - auto q_proj = std::dynamic_pointer_cast(blocks["q"]); - auto k_proj = std::dynamic_pointer_cast(blocks["k"]); - auto v_proj = std::dynamic_pointer_cast(blocks["v"]); - auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); - - auto h_ = norm->forward(ctx, x); - - const int64_t n = h_->ne[3]; - const int64_t c = h_->ne[2]; - const int64_t h = h_->ne[1]; - const int64_t w = h_->ne[0]; - - ggml_tensor* q; - ggml_tensor* k; - ggml_tensor* v; - if (use_linear) { - h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 2, 0, 3)); // [N, h, w, in_channels] - h_ = ggml_reshape_3d(ctx->ggml_ctx, h_, c, h * w, n); // [N, h * w, in_channels] - - q = q_proj->forward(ctx, h_); // [N, h * w, in_channels] - k = k_proj->forward(ctx, h_); // [N, h * w, in_channels] - v = v_proj->forward(ctx, h_); // [N, h * w, in_channels] - } else { - q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] - q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] - q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels] - - k = k_proj->forward(ctx, h_); // [N, in_channels, h, w] - k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] - k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels] - - v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] - v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels] - v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels] - } - - h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, ctx->flash_attn_enabled); - - if (use_linear) { - h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] - - h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] - h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w] - } else { - h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] - h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w] - - h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w] - } - - h_ = ggml_add(ctx->ggml_ctx, h_, x); - return h_; - } -}; - -class AE3DConv : public Conv2d { -public: - AE3DConv(int64_t in_channels, - int64_t out_channels, - std::pair kernel_size, - int video_kernel_size = 3, - std::pair stride = {1, 1}, - std::pair padding = {0, 0}, - std::pair dilation = {1, 1}, - bool bias = true) - : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) { - int kernel_padding = video_kernel_size / 2; - blocks["time_mix_conv"] = std::shared_ptr(new Conv3d(out_channels, - out_channels, - {video_kernel_size, 1, 1}, - {1, 1, 1}, - {kernel_padding, 0, 0})); - } - - struct ggml_tensor* forward(GGMLRunnerContext* ctx, - struct ggml_tensor* x) override { - // timesteps always None - // skip_video always False - // x: [N, IC, IH, IW] - // result: [N, OC, OH, OW] - auto time_mix_conv = std::dynamic_pointer_cast(blocks["time_mix_conv"]); - - x = Conv2d::forward(ctx, x); - // timesteps = x.shape[0] - // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) - // x = conv3d(x) - // return rearrange(x, "b c t h w -> (b t) c h w") - int64_t T = x->ne[3]; - int64_t B = x->ne[3] / T; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - - x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) - x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW] - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) - x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w - return x; // [B*T, OC, OH, OW] - } -}; - -class VideoResnetBlock : public ResnetBlock { -protected: - void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { - enum ggml_type wtype = get_type(prefix + "mix_factor", tensor_storage_map, GGML_TYPE_F32); - params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); - } - - float get_alpha() { - float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]); - return sigmoid(alpha); - } - -public: - VideoResnetBlock(int64_t in_channels, - int64_t out_channels, - int video_kernel_size = 3) - : ResnetBlock(in_channels, out_channels) { - // merge_strategy is always learned - blocks["time_stack"] = std::shared_ptr(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true)); - } - - struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { - // x: [N, in_channels, h, w] aka [b*t, in_channels, h, w] - // return: [N, out_channels, h, w] aka [b*t, out_channels, h, w] - // t_emb is always None - // skip_video is always False - // timesteps is always None - auto time_stack = std::dynamic_pointer_cast(blocks["time_stack"]); - - x = ResnetBlock::forward(ctx, x); // [N, out_channels, h, w] - // return x; - - int64_t T = x->ne[3]; - int64_t B = x->ne[3] / T; - int64_t C = x->ne[2]; - int64_t H = x->ne[1]; - int64_t W = x->ne[0]; - - x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) - auto x_mix = x; - - x = time_stack->forward(ctx, x); // b t c (h w) - - float alpha = get_alpha(); - x = ggml_add(ctx->ggml_ctx, - ggml_ext_scale(ctx->ggml_ctx, x, alpha), - ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); - - x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) - x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w - - return x; - } -}; - -// ldm.modules.diffusionmodules.model.Encoder -class Encoder : public GGMLBlock { -protected: - int ch = 128; - std::vector ch_mult = {1, 2, 4, 4}; - int num_res_blocks = 2; - int in_channels = 3; - int z_channels = 4; - bool double_z = true; - -public: - Encoder(int ch, - std::vector ch_mult, - int num_res_blocks, - int in_channels, - int z_channels, - bool double_z = true, - bool use_linear_projection = false) - : ch(ch), - ch_mult(ch_mult), - num_res_blocks(num_res_blocks), - in_channels(in_channels), - z_channels(z_channels), - double_z(double_z) { - blocks["conv_in"] = std::shared_ptr(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1})); - - size_t num_resolutions = ch_mult.size(); - - int block_in = 1; - for (int i = 0; i < num_resolutions; i++) { - if (i == 0) { - block_in = ch; - } else { - block_in = ch * ch_mult[i - 1]; - } - int block_out = ch * ch_mult[i]; - for (int j = 0; j < num_res_blocks; j++) { - std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j); - blocks[name] = std::shared_ptr(new ResnetBlock(block_in, block_out)); - block_in = block_out; - } - if (i != num_resolutions - 1) { - std::string name = "down." + std::to_string(i) + ".downsample"; - blocks[name] = std::shared_ptr(new DownSampleBlock(block_in, block_in, true)); - } - } - - blocks["mid.block_1"] = std::shared_ptr(new ResnetBlock(block_in, block_in)); - blocks["mid.attn_1"] = std::shared_ptr(new AttnBlock(block_in, use_linear_projection)); - blocks["mid.block_2"] = std::shared_ptr(new ResnetBlock(block_in, block_in)); - - blocks["norm_out"] = std::shared_ptr(new GroupNorm32(block_in)); - blocks["conv_out"] = std::shared_ptr(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1})); - } - - virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { - // x: [N, in_channels, h, w] - - auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); - auto mid_block_1 = std::dynamic_pointer_cast(blocks["mid.block_1"]); - auto mid_attn_1 = std::dynamic_pointer_cast(blocks["mid.attn_1"]); - auto mid_block_2 = std::dynamic_pointer_cast(blocks["mid.block_2"]); - auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); - auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); - - auto h = conv_in->forward(ctx, x); // [N, ch, h, w] - - // downsampling - size_t num_resolutions = ch_mult.size(); - for (int i = 0; i < num_resolutions; i++) { - for (int j = 0; j < num_res_blocks; j++) { - std::string name = "down." + std::to_string(i) + ".block." + std::to_string(j); - auto down_block = std::dynamic_pointer_cast(blocks[name]); - - h = down_block->forward(ctx, h); - } - if (i != num_resolutions - 1) { - std::string name = "down." + std::to_string(i) + ".downsample"; - auto down_sample = std::dynamic_pointer_cast(blocks[name]); - - h = down_sample->forward(ctx, h); - } - } - - // middle - h = mid_block_1->forward(ctx, h); - h = mid_attn_1->forward(ctx, h); - h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] - - // end - h = norm_out->forward(ctx, h); - h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish - h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w] - return h; - } -}; - -// ldm.modules.diffusionmodules.model.Decoder -class Decoder : public GGMLBlock { -protected: - int ch = 128; - int out_ch = 3; - std::vector ch_mult = {1, 2, 4, 4}; - int num_res_blocks = 2; - int z_channels = 4; - bool video_decoder = false; - int video_kernel_size = 3; - - virtual std::shared_ptr get_conv_out(int64_t in_channels, - int64_t out_channels, - std::pair kernel_size, - std::pair stride = {1, 1}, - std::pair padding = {0, 0}) { - if (video_decoder) { - return std::shared_ptr(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding)); - } else { - return std::shared_ptr(new Conv2d(in_channels, out_channels, kernel_size, stride, padding)); - } - } - - virtual std::shared_ptr get_resnet_block(int64_t in_channels, - int64_t out_channels) { - if (video_decoder) { - return std::shared_ptr(new VideoResnetBlock(in_channels, out_channels, video_kernel_size)); - } else { - return std::shared_ptr(new ResnetBlock(in_channels, out_channels)); - } - } - -public: - Decoder(int ch, - int out_ch, - std::vector ch_mult, - int num_res_blocks, - int z_channels, - bool use_linear_projection = false, - bool video_decoder = false, - int video_kernel_size = 3) - : ch(ch), - out_ch(out_ch), - ch_mult(ch_mult), - num_res_blocks(num_res_blocks), - z_channels(z_channels), - video_decoder(video_decoder), - video_kernel_size(video_kernel_size) { - int num_resolutions = static_cast(ch_mult.size()); - int block_in = ch * ch_mult[num_resolutions - 1]; - - blocks["conv_in"] = std::shared_ptr(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1})); - - blocks["mid.block_1"] = get_resnet_block(block_in, block_in); - blocks["mid.attn_1"] = std::shared_ptr(new AttnBlock(block_in, use_linear_projection)); - blocks["mid.block_2"] = get_resnet_block(block_in, block_in); - - for (int i = num_resolutions - 1; i >= 0; i--) { - int mult = ch_mult[i]; - int block_out = ch * mult; - for (int j = 0; j < num_res_blocks + 1; j++) { - std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); - blocks[name] = get_resnet_block(block_in, block_out); - - block_in = block_out; - } - if (i != 0) { - std::string name = "up." + std::to_string(i) + ".upsample"; - blocks[name] = std::shared_ptr(new UpSampleBlock(block_in, block_in)); - } - } - - blocks["norm_out"] = std::shared_ptr(new GroupNorm32(block_in)); - blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1}); - } - - virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) { - // z: [N, z_channels, h, w] - // alpha is always 0 - // merge_strategy is always learned - // time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock - // AttnVideoBlock will not be used - auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); - auto mid_block_1 = std::dynamic_pointer_cast(blocks["mid.block_1"]); - auto mid_attn_1 = std::dynamic_pointer_cast(blocks["mid.attn_1"]); - auto mid_block_2 = std::dynamic_pointer_cast(blocks["mid.block_2"]); - auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); - auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); - - // conv_in - auto h = conv_in->forward(ctx, z); // [N, block_in, h, w] - - // middle - h = mid_block_1->forward(ctx, h); - // return h; - - h = mid_attn_1->forward(ctx, h); - h = mid_block_2->forward(ctx, h); // [N, block_in, h, w] - - // upsampling - int num_resolutions = static_cast(ch_mult.size()); - for (int i = num_resolutions - 1; i >= 0; i--) { - for (int j = 0; j < num_res_blocks + 1; j++) { - std::string name = "up." + std::to_string(i) + ".block." + std::to_string(j); - auto up_block = std::dynamic_pointer_cast(blocks[name]); - - h = up_block->forward(ctx, h); - } - if (i != 0) { - std::string name = "up." + std::to_string(i) + ".upsample"; - auto up_sample = std::dynamic_pointer_cast(blocks[name]); - - h = up_sample->forward(ctx, h); - } - } - - h = norm_out->forward(ctx, h); - h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish - h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8] - return h; - } -}; - -// ldm.models.autoencoder.AutoencoderKL -class AutoencodingEngine : public GGMLBlock { +struct VAE : public GGMLRunner { protected: SDVersion version; - bool decode_only = true; - bool use_video_decoder = false; - bool use_quant = true; - int embed_dim = 4; - struct { - int z_channels = 4; - int resolution = 256; - int in_channels = 3; - int out_ch = 3; - int ch = 128; - std::vector ch_mult = {1, 2, 4, 4}; - int num_res_blocks = 2; - bool double_z = true; - } dd_config; + bool scale_input = true; + virtual bool _compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx) = 0; public: - AutoencodingEngine(SDVersion version = VERSION_SD1, - bool decode_only = true, - bool use_linear_projection = false, - bool use_video_decoder = false) - : version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (sd_version_is_dit(version)) { - if (sd_version_is_flux2(version)) { - dd_config.z_channels = 32; - embed_dim = 32; + VAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu) + : version(version), GGMLRunner(backend, offload_params_to_cpu) {} + + int get_scale_factor() { + int scale_factor = 8; + if (version == VERSION_WAN2_2_TI2V) { + scale_factor = 16; + } else if (sd_version_is_flux2(version)) { + scale_factor = 16; + } else if (version == VERSION_CHROMA_RADIANCE) { + scale_factor = 1; + } + return scale_factor; + } + + virtual int get_encoder_output_channels(int input_channels) = 0; + + void get_tile_sizes(int& tile_size_x, + int& tile_size_y, + float& tile_overlap, + const sd_tiling_params_t& params, + int64_t latent_x, + int64_t latent_y, + float encoding_factor = 1.0f) { + tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f); + auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) { + const int default_tile_size = 32; + const int min_tile_dimension = 4; + int tile_size = default_tile_size; + // factor <= 1 means simple fraction of the latent dimension + // factor > 1 means number of tiles across that dimension + if (factor > 0.f) { + if (factor > 1.0) + factor = 1 / (factor - factor * tile_overlap + tile_overlap); + tile_size = static_cast(std::round(latent_size * factor)); + } else if (requested_size >= min_tile_dimension) { + tile_size = requested_size; + } + tile_size = static_cast(tile_size * encoding_factor); + return std::max(std::min(tile_size, static_cast(latent_size)), min_tile_dimension); + }; + + tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x); + tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y); + } + + ggml_tensor* encode(int n_threads, + ggml_context* work_ctx, + ggml_tensor* x, + sd_tiling_params_t tiling_params, + bool circular_x = false, + bool circular_y = false) { + int64_t t0 = ggml_time_ms(); + ggml_tensor* result = nullptr; + const int scale_factor = get_scale_factor(); + int64_t W = x->ne[0] / scale_factor; + int64_t H = x->ne[1] / scale_factor; + int channel_dim = sd_version_is_wan(version) ? 3 : 2; + int64_t C = get_encoder_output_channels(static_cast(x->ne[channel_dim])); + int64_t ne2; + int64_t ne3; + if (sd_version_is_wan(version)) { + int64_t T = x->ne[2]; + ne2 = (T - 1) / 4 + 1; + ne3 = C; + } else { + ne2 = C; + ne3 = x->ne[3]; + } + result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3); + + if (scale_input) { + scale_to_minus1_1(x); + } + + if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { + x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); + } + + if (tiling_params.enabled) { + float tile_overlap; + int tile_size_x, tile_size_y; + // multiply tile size for encode to keep the compute buffer size consistent + get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, W, H, 1.30539f); + + LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); + + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + return _compute(n_threads, in, false, &out, work_ctx); + }; + sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling); + } else { + _compute(n_threads, x, false, &result, work_ctx); + } + free_compute_buffer(); + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + return result; + } + + ggml_tensor* decode(int n_threads, + ggml_context* work_ctx, + ggml_tensor* x, + sd_tiling_params_t tiling_params, + bool decode_video = false, + bool circular_x = false, + bool circular_y = false, + ggml_tensor* result = nullptr, + bool silent = false) { + const int scale_factor = get_scale_factor(); + int64_t W = x->ne[0] * scale_factor; + int64_t H = x->ne[1] * scale_factor; + int64_t C = 3; + if (result == nullptr) { + if (decode_video) { + int64_t T = x->ne[2]; + if (sd_version_is_wan(version)) { + T = ((T - 1) * 4) + 1; + } + result = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + W, + H, + T, + 3); } else { - use_quant = false; - dd_config.z_channels = 16; + result = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + W, + H, + C, + x->ne[3]); } } - if (use_video_decoder) { - use_quant = false; + int64_t t0 = ggml_time_ms(); + if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { + x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); } - blocks["decoder"] = std::shared_ptr(new Decoder(dd_config.ch, - dd_config.out_ch, - dd_config.ch_mult, - dd_config.num_res_blocks, - dd_config.z_channels, - use_linear_projection, - use_video_decoder)); - if (use_quant) { - blocks["post_quant_conv"] = std::shared_ptr(new Conv2d(dd_config.z_channels, - embed_dim, - {1, 1})); - } - if (!decode_only) { - blocks["encoder"] = std::shared_ptr(new Encoder(dd_config.ch, - dd_config.ch_mult, - dd_config.num_res_blocks, - dd_config.in_channels, - dd_config.z_channels, - dd_config.double_z, - use_linear_projection)); - if (use_quant) { - int factor = dd_config.double_z ? 2 : 1; + if (tiling_params.enabled) { + float tile_overlap; + int tile_size_x, tile_size_y; + get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, x->ne[0], x->ne[1]); - blocks["quant_conv"] = std::shared_ptr(new Conv2d(embed_dim * factor, - dd_config.z_channels * factor, - {1, 1})); + if (!silent) { + LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); + } + + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + return _compute(n_threads, in, true, &out, nullptr); + }; + sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling, silent); + } else { + if (!_compute(n_threads, x, true, &result, work_ctx)) { + LOG_ERROR("Failed to decode latetnts"); + free_compute_buffer(); + return nullptr; } } + free_compute_buffer(); + if (scale_input) { + scale_to_0_1(result); + } + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); + ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f); + return result; } - struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { - // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version)) { - // [N, C*p*p, h, w] -> [N, C, h*p, w*p] - int64_t p = 2; - - int64_t N = z->ne[3]; - int64_t C = z->ne[2] / p / p; - int64_t h = z->ne[1]; - int64_t w = z->ne[0]; - int64_t H = h * p; - int64_t W = w * p; - - z = ggml_reshape_4d(ctx->ggml_ctx, z, w * h, p * p, C, N); // [N, C, p*p, h*w] - z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, h*w, p*p] - z = ggml_reshape_4d(ctx->ggml_ctx, z, p, p, w, h * C * N); // [N*C*h, w, p, p] - z = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, p, w, p] - z = ggml_reshape_4d(ctx->ggml_ctx, z, W, H, C, N); // [N, C, h*p, w*p] - } - - if (use_quant) { - auto post_quant_conv = std::dynamic_pointer_cast(blocks["post_quant_conv"]); - z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w] - } - auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); - - ggml_set_name(z, "bench-start"); - auto h = decoder->forward(ctx, z); - ggml_set_name(h, "bench-end"); - return h; - } - - struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { - // x: [N, in_channels, h, w] - auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); - - auto z = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8] - if (use_quant) { - auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); - z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] - } - if (sd_version_is_flux2(version)) { - z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; - - // [N, C, H, W] -> [N, C*p*p, H/p, W/p] - int64_t p = 2; - int64_t N = z->ne[3]; - int64_t C = z->ne[2]; - int64_t H = z->ne[1]; - int64_t W = z->ne[0]; - int64_t h = H / p; - int64_t w = W / p; - - z = ggml_reshape_4d(ctx->ggml_ctx, z, p, w, p, h * C * N); // [N*C*h, p, w, p] - z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 2, 1, 3)); // [N*C*h, w, p, p] - z = ggml_reshape_4d(ctx->ggml_ctx, z, p * p, w * h, C, N); // [N, C, h*w, p*p] - z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 1, 0, 2, 3)); // [N, C, p*p, h*w] - z = ggml_reshape_4d(ctx->ggml_ctx, z, w, h, p * p * C, N); // [N, C*p*p, h*w] - } - return z; - } -}; - -struct VAE : public GGMLRunner { - VAE(ggml_backend_t backend, bool offload_params_to_cpu) - : GGMLRunner(backend, offload_params_to_cpu) {} - virtual bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx) = 0; - virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0; + virtual ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr rng) = 0; + virtual ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0; + virtual ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0; + virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0; virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }; }; struct FakeVAE : public VAE { - FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu) - : VAE(backend, offload_params_to_cpu) {} - bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx) override { + FakeVAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu) + : VAE(version, backend, offload_params_to_cpu) {} + + int get_encoder_output_channels(int input_channels) { + return input_channels; + } + + bool _compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx) override { if (*output == nullptr && output_ctx != nullptr) { *output = ggml_dup_tensor(output_ctx, z); } @@ -642,6 +213,18 @@ struct FakeVAE : public VAE { return true; } + ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr rng) { + return vae_output; + } + + ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) { + return ggml_ext_dup_and_cpy_tensor(work_ctx, latents); + } + + ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) { + return ggml_ext_dup_and_cpy_tensor(work_ctx, latents); + } + void get_param_tensors(std::map& tensors, const std::string prefix) override {} std::string get_desc() override { @@ -649,126 +232,4 @@ struct FakeVAE : public VAE { } }; -struct AutoEncoderKL : public VAE { - bool decode_only = true; - AutoencodingEngine ae; - - AutoEncoderKL(ggml_backend_t backend, - bool offload_params_to_cpu, - const String2TensorStorage& tensor_storage_map, - const std::string prefix, - bool decode_only = false, - bool use_video_decoder = false, - SDVersion version = VERSION_SD1) - : decode_only(decode_only), VAE(backend, offload_params_to_cpu) { - bool use_linear_projection = false; - for (const auto& [name, tensor_storage] : tensor_storage_map) { - if (!starts_with(name, prefix)) { - continue; - } - if (ends_with(name, "attn_1.proj_out.weight")) { - if (tensor_storage.n_dims == 2) { - use_linear_projection = true; - } - break; - } - } - ae = AutoencodingEngine(version, decode_only, use_linear_projection, use_video_decoder); - ae.init(params_ctx, tensor_storage_map, prefix); - } - - void set_conv2d_scale(float scale) override { - std::vector blocks; - ae.get_all_blocks(blocks); - for (auto block : blocks) { - if (block->get_desc() == "Conv2d") { - auto conv_block = (Conv2d*)block; - conv_block->set_scale(scale); - } - } - } - - std::string get_desc() override { - return "vae"; - } - - void get_param_tensors(std::map& tensors, const std::string prefix) override { - ae.get_param_tensors(tensors, prefix); - } - - struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { - struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); - - z = to_backend(z); - - auto runner_ctx = get_context(); - - struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z); - - ggml_build_forward_expand(gf, out); - - return gf; - } - - bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx = nullptr) override { - GGML_ASSERT(!decode_only || decode_graph); - auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(z, decode_graph); - }; - // ggml_set_f32(z, 0.5f); - // print_ggml_tensor(z); - return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); - } - - void test() { - struct ggml_init_params params; - params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - params.mem_buffer = nullptr; - params.no_alloc = false; - - struct ggml_context* work_ctx = ggml_init(params); - GGML_ASSERT(work_ctx != nullptr); - - { - // CPU, x{1, 3, 64, 64}: Pass - // CUDA, x{1, 3, 64, 64}: Pass, but sill get wrong result for some image, may be due to interlnal nan - // CPU, x{2, 3, 64, 64}: Wrong result - // CUDA, x{2, 3, 64, 64}: Wrong result, and different from CPU result - auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 64, 64, 3, 2); - ggml_set_f32(x, 0.5f); - print_ggml_tensor(x); - struct ggml_tensor* out = nullptr; - - int64_t t0 = ggml_time_ms(); - compute(8, x, false, &out, work_ctx); - int64_t t1 = ggml_time_ms(); - - print_ggml_tensor(out); - LOG_DEBUG("encode test done in %lldms", t1 - t0); - } - - if (false) { - // CPU, z{1, 4, 8, 8}: Pass - // CUDA, z{1, 4, 8, 8}: Pass - // CPU, z{3, 4, 8, 8}: Wrong result - // CUDA, z{3, 4, 8, 8}: Wrong result, and different from CPU result - auto z = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); - ggml_set_f32(z, 0.5f); - print_ggml_tensor(z); - struct ggml_tensor* out = nullptr; - - int64_t t0 = ggml_time_ms(); - compute(8, z, true, &out, work_ctx); - int64_t t1 = ggml_time_ms(); - - print_ggml_tensor(out); - LOG_DEBUG("decode test done in %lldms", t1 - t0); - } - }; -}; - -#endif +#endif // __VAE_HPP__ diff --git a/src/wan.hpp b/src/wan.hpp index d94fbd4..2311955 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -1109,7 +1109,8 @@ namespace WAN { }; struct WanVAERunner : public VAE { - bool decode_only = true; + float scale_factor = 1.0f; + bool decode_only = true; WanVAE ae; WanVAERunner(ggml_backend_t backend, @@ -1118,7 +1119,7 @@ namespace WAN { const std::string prefix = "", bool decode_only = false, SDVersion version = VERSION_WAN2) - : decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(backend, offload_params_to_cpu) { + : decode_only(decode_only), ae(decode_only, version == VERSION_WAN2_2_TI2V), VAE(version, backend, offload_params_to_cpu) { ae.init(params_ctx, tensor_storage_map, prefix); } @@ -1130,6 +1131,101 @@ namespace WAN { ae.get_param_tensors(tensors, prefix); } + ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr rng) { + return vae_output; + } + + void get_latents_mean_std_vec(ggml_tensor* latents, int channel_dim, std::vector& latents_mean_vec, std::vector& latents_std_vec) { + GGML_ASSERT(latents->ne[channel_dim] == 16 || latents->ne[channel_dim] == 48); + if (latents->ne[channel_dim] == 16) { // Wan2.1 VAE + latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f, + 0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f}; + latents_std_vec = {2.8184f, 1.4541f, 2.3275f, 2.6558f, 1.2196f, 1.7708f, 2.6052f, 2.0743f, + 3.2687f, 2.1526f, 2.8652f, 1.5579f, 1.6382f, 1.1253f, 2.8251f, 1.9160f}; + } else if (latents->ne[channel_dim] == 48) { // Wan2.2 VAE + latents_mean_vec = {-0.2289f, -0.0052f, -0.1323f, -0.2339f, -0.2799f, 0.0174f, 0.1838f, 0.1557f, + -0.1382f, 0.0542f, 0.2813f, 0.0891f, 0.1570f, -0.0098f, 0.0375f, -0.1825f, + -0.2246f, -0.1207f, -0.0698f, 0.5109f, 0.2665f, -0.2108f, -0.2158f, 0.2502f, + -0.2055f, -0.0322f, 0.1109f, 0.1567f, -0.0729f, 0.0899f, -0.2799f, -0.1230f, + -0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f, + 0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f}; + latents_std_vec = { + 0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f, + 0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f, + 0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f, + 0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f, + 0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f, + 0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f}; + } + } + + ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) { + ggml_tensor* vae_latents = ggml_dup(work_ctx, latents); + int channel_dim = sd_version_is_wan(version) ? 3 : 2; + std::vector latents_mean_vec; + std::vector latents_std_vec; + get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec); + + float mean; + float std_; + for (int i = 0; i < latents->ne[3]; i++) { + if (channel_dim == 3) { + mean = latents_mean_vec[i]; + std_ = latents_std_vec[i]; + } + for (int j = 0; j < latents->ne[2]; j++) { + if (channel_dim == 2) { + mean = latents_mean_vec[j]; + std_ = latents_std_vec[j]; + } + for (int k = 0; k < latents->ne[1]; k++) { + for (int l = 0; l < latents->ne[0]; l++) { + float value = ggml_ext_tensor_get_f32(latents, l, k, j, i); + value = value * std_ / scale_factor + mean; + ggml_ext_tensor_set_f32(vae_latents, value, l, k, j, i); + } + } + } + } + + return vae_latents; + } + + ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) { + ggml_tensor* diffusion_latents = ggml_dup(work_ctx, latents); + int channel_dim = sd_version_is_wan(version) ? 3 : 2; + std::vector latents_mean_vec; + std::vector latents_std_vec; + get_latents_mean_std_vec(latents, channel_dim, latents_mean_vec, latents_std_vec); + + float mean; + float std_; + for (int i = 0; i < latents->ne[3]; i++) { + if (channel_dim == 3) { + mean = latents_mean_vec[i]; + std_ = latents_std_vec[i]; + } + for (int j = 0; j < latents->ne[2]; j++) { + if (channel_dim == 2) { + mean = latents_mean_vec[j]; + std_ = latents_std_vec[j]; + } + for (int k = 0; k < latents->ne[1]; k++) { + for (int l = 0; l < latents->ne[0]; l++) { + float value = ggml_ext_tensor_get_f32(latents, l, k, j, i); + value = (value - mean) * scale_factor / std_; + ggml_ext_tensor_set_f32(diffusion_latents, value, l, k, j, i); + } + } + } + } + return diffusion_latents; + } + + int get_encoder_output_channels(int input_channels) { + return static_cast(ae.z_dim); + } + struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = new_graph_custom(10240 * z->ne[2]); @@ -1173,11 +1269,11 @@ namespace WAN { return gf; } - bool compute(const int n_threads, - struct ggml_tensor* z, - bool decode_graph, - struct ggml_tensor** output, - struct ggml_context* output_ctx = nullptr) override { + bool _compute(const int n_threads, + struct ggml_tensor* z, + bool decode_graph, + struct ggml_tensor** output, + struct ggml_context* output_ctx = nullptr) override { if (true) { auto get_graph = [&]() -> struct ggml_cgraph* { return build_graph(z, decode_graph); @@ -1249,7 +1345,7 @@ namespace WAN { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - compute(8, z, true, &out, work_ctx); + _compute(8, z, true, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out);