#ifndef __SD_LTX_VAE_HPP__
#define __SD_LTX_VAE_HPP__
#include <algorithm>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ltxv.hpp"
#include "vae.hpp"
#include "wan.hpp"
namespace LTXVAE {
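// x * (1 + scale) + shift: the AdaLN-style modulation used throughout this VAE.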
static inline ggml_tensor* apply_scale_shift(ggml_context* ctx,
ggml_tensor* x,
ggml_tensor* scale,
ggml_tensor* shift) {
x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
x = ggml_add(ctx, x, shift);
return x;
}
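// Reshape a length-N vector to ne = [1, 1, 1, N] so ggml broadcasting applies it
// per channel of a [W, H, T, C] activation.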
static inline ggml_tensor* reshape_channel_broadcast(ggml_context* ctx,
ggml_tensor* x) {
return ggml_reshape_4d(ctx, x, 1, 1, 1, ggml_nelements(x));
}
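// Add a learned base table to the projected timestep embedding and split the sum
// into `parts` chunks, each reshaped for per-channel broadcast. Used by the
// decoder's final shift/scale pair (parts == 2).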
static inline std::pair<ggml_tensor*, ggml_tensor*> get_shift_scale(ggml_context* ctx,
ggml_tensor* table,
ggml_tensor* timestep,
int64_t channels,
int parts) {
GGML_ASSERT(timestep != nullptr);
GGML_ASSERT(ggml_nelements(timestep) == channels * parts);
auto timestep_view = ggml_reshape_2d(ctx, timestep, channels, parts);
auto values = ggml_add(ctx, table, timestep_view);
auto chunks = ggml_ext_chunk(ctx, values, parts, 1, false);
auto shift = reshape_channel_broadcast(ctx, ggml_cont(ctx, chunks[0]));
auto scale = reshape_channel_broadcast(ctx, ggml_cont(ctx, chunks[1]));
return {shift, scale};
}
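// 3D pixel shuffle: folds the (p1, p2, p3) channel factors back into the T, H
// and W axes. drop_first_temporal_frame slices off the extra leading frame that
// temporal upsampling produces, presumably because the causal first frame is
// not meant to be temporally expanded.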
static inline ggml_tensor* depth_to_space_3d(ggml_context* ctx,
ggml_tensor* x,
int64_t c,
int factor_t,
int factor_s,
bool drop_first_temporal_frame) {
// x: [B*c*p1*p2*p3, T, H, W], B == 1, p2 == p3 == factor_s, p1 == factor_t
// return: [B*c, T*p1, H*p2, W*p3]
// Match: rearrange(x, "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)")
const int64_t T = x->ne[2];
const int64_t H = x->ne[1];
const int64_t W = x->ne[0];
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // [T, C, H, W]
x = ggml_reshape_4d(ctx, x, W, H, factor_s, factor_s * factor_t * c * T); // [T*c*p1*p2, p3, H, W]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [T*c*p1*p2, H, W, p3]
x = ggml_reshape_4d(ctx, x, factor_s * W, H, factor_s, factor_t * c * T); // [T*c*p1, p2, H, W*p3]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [T*c*p1, H, p2, W*p3]
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, factor_t, c, T); // [T, c, p1, H*p2*W*p3]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // [c, T, p1, H*p2*W*p3]
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, c); // [c, T*p1, H*p2, W*p3]
if (drop_first_temporal_frame && factor_t > 1 && x->ne[2] > 0) {
x = ggml_ext_slice(ctx, x, 2, 1, x->ne[2]);
}
return x;
}
static inline ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x,
int patch_size) {
return WAN::WanVAE::patchify(ctx, x, patch_size, 1);
}
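// Conv3d with symmetric spatial padding inside the convolution and explicit
// temporal padding by edge-frame replication: causal mode prepends
// time_kernel_size - 1 copies of the first frame; non-causal mode replicates
// (time_kernel_size - 1) / 2 frames on each side. Assumes an odd kernel size of
// at least 3 (always 3 in this file).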
class CausalConv3d : public GGMLBlock {
protected:
int time_kernel_size;
public:
CausalConv3d(int64_t in_channels,
int64_t out_channels,
int kernel_size = 3,
std::tuple<int, int, int> stride = {1, 1, 1},
int dilation = 1,
bool bias = true) {
time_kernel_size = kernel_size;
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
out_channels,
{kernel_size, kernel_size, kernel_size},
stride,
{0, kernel_size / 2, kernel_size / 2},
{dilation, 1, 1},
bias));
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal = true) {
// x: [B*C, T, H, W], B == 1
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
if (causal) {
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame;
for (int i = 1; i < time_kernel_size - 1; i++) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
}
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
} else {
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame;
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
}
auto last_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - 1, x->ne[2]);
auto last_frame_pad = last_frame;
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
last_frame_pad = ggml_concat(ctx->ggml_ctx, last_frame_pad, last_frame, 2);
}
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
x = ggml_concat(ctx->ggml_ctx, x, last_frame_pad, 2);
}
return conv->forward(ctx, x);
}
};
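// RMS-normalizes every (t, h, w) location across the channel dimension.
// ggml_rms_norm reduces over ne[0], so channels are permuted to the fastest
// axis and back.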
struct PixelNorm3D : public UnaryBlock {
float eps;
PixelNorm3D(float eps = 1e-8f)
: eps(eps) {}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2));
h = ggml_rms_norm(ctx->ggml_ctx, h, eps);
h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0));
return h;
}
};
struct PixArtAlphaCombinedTimestepSizeEmbeddings : public GGMLBlock {
int64_t embedding_dim;
PixArtAlphaCombinedTimestepSizeEmbeddings(int64_t embedding_dim)
: embedding_dim(embedding_dim) {
blocks["timestep_embedder"] = std::make_shared<LTXV::TimestepEmbedder>(embedding_dim);
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) {
auto timestep_embedder = std::dynamic_pointer_cast<LTXV::TimestepEmbedder>(blocks["timestep_embedder"]);
return timestep_embedder->forward(ctx, timestep);
}
};
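// Two PixelNorm -> (optional AdaLN) -> SiLU -> CausalConv3d stages with an
// identity skip (in/out channels always match here). With timestep
// conditioning, the learned scale_shift_table provides base values that are
// added to the projected timestep and chunked into shift1/scale1/shift2/scale2,
// mirroring get_shift_scale but with four parts.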
struct ResnetBlock3D : public GGMLBlock {
int64_t channels;
bool timestep_conditioning;
protected:
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
if (timestep_conditioning) {
params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, channels, 4);
}
}
public:
ResnetBlock3D(int64_t channels,
float eps = 1e-6f,
bool timestep_conditioning = false)
: channels(channels), timestep_conditioning(timestep_conditioning) {
blocks["norm1"] = std::make_shared<PixelNorm3D>(eps);
blocks["conv1"] = std::make_shared<CausalConv3d>(channels, channels, 3);
blocks["norm2"] = std::make_shared<PixelNorm3D>(eps);
blocks["conv2"] = std::make_shared<CausalConv3d>(channels, channels, 3);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep = nullptr,
bool causal = false) {
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
ggml_tensor* shift1 = nullptr;
ggml_tensor* scale1 = nullptr;
ggml_tensor* shift2 = nullptr;
ggml_tensor* scale2 = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
auto values = ggml_add(ctx->ggml_ctx,
params["scale_shift_table"],
ggml_reshape_2d(ctx->ggml_ctx, timestep, channels, 4));
auto chunks = ggml_ext_chunk(ctx->ggml_ctx, values, 4, 1, false);
shift1 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[0]));
scale1 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[1]));
shift2 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[2]));
scale2 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[3]));
}
auto h = norm1->forward(ctx, x);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h, causal);
h = norm2->forward(ctx, h);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h, causal);
return ggml_add(ctx->ggml_ctx, h, x);
}
};
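// Despite the UNet-style name, this is a plain stack of ResnetBlock3Ds with one
// shared optional timestep embedder; there is no attention layer.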
struct UNetMidBlock3D : public GGMLBlock {
int64_t channels;
int num_layers;
bool timestep_conditioning;
UNetMidBlock3D(int64_t channels,
int num_layers,
bool timestep_conditioning)
: channels(channels),
num_layers(num_layers),
timestep_conditioning(timestep_conditioning) {
if (timestep_conditioning) {
blocks["time_embedder"] = std::make_shared<PixArtAlphaCombinedTimestepSizeEmbeddings>(channels * 4);
}
for (int i = 0; i < num_layers; i++) {
blocks["res_blocks." + std::to_string(i)] = std::make_shared<ResnetBlock3D>(channels, 1e-6f, timestep_conditioning);
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep = nullptr,
bool causal = false) {
ggml_tensor* timestep_embed = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
auto time_embedder = std::dynamic_pointer_cast<PixArtAlphaCombinedTimestepSizeEmbeddings>(blocks["time_embedder"]);
timestep_embed = time_embedder->forward(ctx, timestep);
}
for (int i = 0; i < num_layers; i++) {
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
x = resnet->forward(ctx, x, timestep_embed, causal);
}
return x;
}
};
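// Upsample by expanding channels with a causal conv (factor_t * factor_s^2 *
// in_channels / reduction outputs) and pixel-shuffling them into T/H/W. The
// residual path pixel-shuffles the input directly and tiles it along the
// channel axis to match the reduced output width.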
struct DepthToSpaceUpsample : public GGMLBlock {
int64_t in_channels;
int factor_t;
int factor_s;
int out_channels_reduction_factor;
bool residual;
DepthToSpaceUpsample(int64_t in_channels,
int factor_t = 2,
int factor_s = 2,
int out_channels_reduction_factor = 2,
bool residual = true)
: in_channels(in_channels),
factor_t(factor_t),
factor_s(factor_s),
out_channels_reduction_factor(out_channels_reduction_factor),
residual(residual) {
const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s);
const int64_t out_dim = (factor * in_channels) / out_channels_reduction_factor;
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels, out_dim, 3);
}
int64_t get_output_channels() const {
return in_channels / out_channels_reduction_factor;
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal = false) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
ggml_tensor* x_in = nullptr;
if (residual) {
x_in = depth_to_space_3d(ctx->ggml_ctx, x, in_channels / (factor_t * factor_s * factor_s), factor_t, factor_s, factor_t > 1);
int repeat = (factor_t * factor_s * factor_s) / out_channels_reduction_factor;
auto res = x_in;
for (int i = 1; i < repeat; i++) {
res = ggml_concat(ctx->ggml_ctx, res, x_in, 3);
}
x_in = res;
}
x = conv->forward(ctx, x, causal);
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, factor_t > 1);
if (residual) {
x = ggml_add(ctx->ggml_ctx, x, x_in);
}
return x;
}
};
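// Downsample via a causal conv followed by space-to-depth average pooling, plus
// an average-pooled skip branch. For a temporal stride the first frame is
// replicated in front so the causal first-frame layout survives the stride.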
struct SpaceToDepthDownsample : public GGMLBlock {
int64_t in_channels;
int64_t out_channels;
int factor_t;
int factor_s;
SpaceToDepthDownsample(int64_t in_channels,
int64_t out_channels,
int factor_t,
int factor_s)
: in_channels(in_channels),
out_channels(out_channels),
factor_t(factor_t),
factor_s(factor_s) {
const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s);
GGML_ASSERT(out_channels % factor == 0);
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels, out_channels / factor, 3);
blocks["skip_downsample"] = std::make_shared<WAN::AvgDown3D>(in_channels, out_channels, factor_t, factor_s);
blocks["conv_downsample"] = std::make_shared<WAN::AvgDown3D>(out_channels / factor, out_channels, factor_t, factor_s);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal = true) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
auto skip_downsample = std::dynamic_pointer_cast<WAN::AvgDown3D>(blocks["skip_downsample"]);
auto conv_downsample = std::dynamic_pointer_cast<WAN::AvgDown3D>(blocks["conv_downsample"]);
if (factor_t > 1 && x->ne[2] > 0) {
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame;
for (int i = 1; i < factor_t; ++i) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
}
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
}
auto residual = skip_downsample->forward(ctx, x);
auto h = conv->forward(ctx, x, causal);
h = conv_downsample->forward(ctx, h);
return ggml_add(ctx->ggml_ctx, h, residual);
}
};
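// Per-channel latent statistics stored in the checkpoint (the hyphenated names
// "std-of-means" / "mean-of-means" match the checkpoint tensor names);
// normalize/un_normalize convert between raw VAE latents and the normalized
// latents the diffusion model sees.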
struct PerChannelStatistics : public GGMLBlock {
protected:
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
params["std-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
params["mean-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
}
public:
ggml_tensor* un_normalize(GGMLRunnerContext* ctx,
ggml_tensor* x) {
auto std_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["std-of-means"]);
auto mean_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["mean-of-means"]);
return ggml_add(ctx->ggml_ctx, ggml_mul(ctx->ggml_ctx, x, std_tensor), mean_tensor);
}
ggml_tensor* normalize(GGMLRunnerContext* ctx,
ggml_tensor* x) {
auto std_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["std-of-means"]);
auto mean_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["mean-of-means"]);
return ggml_div(ctx->ggml_ctx, ggml_sub(ctx->ggml_ctx, x, mean_tensor), std_tensor);
}
};
struct DecoderConfig {
struct Block {
std::string type;
int num_layers = 0;
int multiplier = 1;
};
std::vector<Block> blocks;
};
struct EncoderConfig {
struct Block {
std::string type;
int num_layers = 0;
int multiplier = 1;
};
std::vector<Block> blocks;
};
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
const std::string& name) {
return tensor_storage_map.find(name) != tensor_storage_map.end();
}
static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map,
const std::string& name,
int64_t fallback = 0) {
auto iter = tensor_storage_map.find(name);
if (iter == tensor_storage_map.end()) {
return fallback;
}
return iter->second.ne[0];
}
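// Reconstruct the decoder layout from tensor names alone: an
// up_blocks.<i>.res_blocks.* bias marks a res_x stack (counting res_blocks
// entries gives num_layers); otherwise up_blocks.<i>.conv.conv.bias marks an
// upsampler whose conv-out / next-channel ratio identifies the shuffle factor:
// 8 -> compress_all (2x time, 2x space), 4 -> compress_space, 2 -> compress_time.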
static inline DecoderConfig infer_decoder_config_from_weights(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int64_t conv_in_channels) {
DecoderConfig cfg;
const std::string decoder_prefix = prefix + ".decoder.up_blocks.";
int64_t current_channels = conv_in_channels;
for (int block_idx = 0;; ++block_idx) {
const std::string block_prefix = decoder_prefix + std::to_string(block_idx);
const std::string res0_bias = block_prefix + ".res_blocks.0.conv1.conv.bias";
const std::string conv_bias = block_prefix + ".conv.conv.bias";
if (has_tensor(tensor_storage_map, res0_bias)) {
int num_layers = 0;
while (has_tensor(tensor_storage_map,
block_prefix + ".res_blocks." + std::to_string(num_layers) + ".conv1.conv.bias")) {
num_layers++;
}
cfg.blocks.push_back({"res_x", num_layers, 1});
current_channels = get_tensor_ne0(tensor_storage_map, res0_bias, current_channels);
continue;
}
if (!has_tensor(tensor_storage_map, conv_bias)) {
break;
}
int64_t next_channels = 0;
for (int next_idx = block_idx + 1;; ++next_idx) {
const std::string next_res0_bias = decoder_prefix + std::to_string(next_idx) + ".res_blocks.0.conv1.conv.bias";
const std::string next_conv_bias = decoder_prefix + std::to_string(next_idx) + ".conv.conv.bias";
if (has_tensor(tensor_storage_map, next_res0_bias)) {
next_channels = get_tensor_ne0(tensor_storage_map, next_res0_bias);
break;
}
if (!has_tensor(tensor_storage_map, next_conv_bias)) {
break;
}
}
if (next_channels <= 0 || current_channels % next_channels != 0) {
next_channels = std::max<int64_t>(1, current_channels / 2);
}
const int64_t conv_out_dim = get_tensor_ne0(tensor_storage_map, conv_bias);
const int64_t reduction = std::max<int64_t>(1, current_channels / next_channels);
const int64_t factor = next_channels > 0 ? conv_out_dim / next_channels : 0;
if (factor == 8) {
cfg.blocks.push_back({"compress_all", 0, static_cast<int>(reduction)});
} else if (factor == 4) {
cfg.blocks.push_back({"compress_space", 0, static_cast<int>(reduction)});
} else if (factor == 2) {
cfg.blocks.push_back({"compress_time", 0, static_cast<int>(reduction)});
} else {
LOG_WARN("unexpected LTX VAE upsample factor at '%s': conv_out=%lld current=%lld next=%lld, falling back to compress_all x%d",
block_prefix.c_str(),
(long long)conv_out_dim,
(long long)current_channels,
(long long)next_channels,
(int)reduction);
cfg.blocks.push_back({"compress_all", 0, static_cast<int>(reduction)});
}
current_channels = next_channels;
}
return cfg;
}
static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map,
const std::string& prefix) {
const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias";
if (tensor_storage_map.find(v2_probe) != tensor_storage_map.end()) {
return 2;
}
return 1;
}
static inline bool detect_ltx_vae_timestep_conditioning(const String2TensorStorage& tensor_storage_map,
const std::string& prefix) {
return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end();
}
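// The encoder layout is hard-coded (v2 checkpoints only); unlike the decoder,
// it is not inferred from the weights.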
static inline EncoderConfig get_encoder_config(int version) {
EncoderConfig cfg;
if (version < 2) {
GGML_ABORT("LTX VAE encoder is only implemented for version >= 2");
}
cfg.blocks = {
{"res_x", 4, 1},
{"compress_space_res", 0, 2},
{"res_x", 6, 1},
{"compress_time_res", 0, 2},
{"res_x", 6, 1},
{"compress_all_res", 0, 2},
{"res_x", 2, 1},
{"compress_all_res", 0, 2},
{"res_x", 2, 1},
};
return cfg;
}
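// Encoder: patchified RGB -> conv_in -> down blocks -> conv_out with
// latent_channels + 1 output channels. The extra last channel (presumably a
// shared log-variance) is repeated latent_channels - 1 times so the result is a
// [mean, logvar]-style tensor with 2 * latent_channels channels, matching
// get_encoder_output_channels() in LTXVideoVAE below.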
struct Encoder : public GGMLBlock {
int version;
int patch_size;
int64_t in_channels;
int64_t latent_channels;
Encoder(int version,
int patch_size = 4,
int64_t in_channels = 3,
int64_t latent_channels = 128)
: version(version),
patch_size(patch_size),
in_channels(in_channels),
latent_channels(latent_channels) {
auto cfg = get_encoder_config(version);
int64_t channels = 128;
int64_t in_dim = in_channels * patch_size * patch_size;
blocks["conv_in"] = std::make_shared<CausalConv3d>(in_dim, channels, 3);
for (int block_idx = 0; block_idx < static_cast<int>(cfg.blocks.size()); ++block_idx) {
const auto& block = cfg.blocks[block_idx];
if (block.type == "res_x") {
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<UNetMidBlock3D>(channels,
block.num_layers,
false);
} else if (block.type == "compress_space_res") {
int64_t next_channels = channels * block.multiplier;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels,
1,
2);
channels = next_channels;
} else if (block.type == "compress_time_res") {
int64_t next_channels = channels * block.multiplier;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels,
2,
1);
channels = next_channels;
} else if (block.type == "compress_all_res") {
int64_t next_channels = channels * block.multiplier;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels,
2,
2);
channels = next_channels;
} else {
GGML_ABORT("Unsupported LTX VAE encoder block");
}
}
blocks["conv_norm_out"] = std::make_shared<PixelNorm3D>();
blocks["conv_out"] = std::make_shared<CausalConv3d>(channels, latent_channels + 1, 3);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
x = conv_in->forward(ctx, x, true);
int block_idx = 0;
while (blocks.find("down_blocks." + std::to_string(block_idx)) != blocks.end()) {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["down_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, nullptr, true);
} else {
auto downsample = std::dynamic_pointer_cast<SpaceToDepthDownsample>(blocks["down_blocks." + std::to_string(block_idx)]);
x = downsample->forward(ctx, x, true);
}
block_idx++;
}
x = conv_norm_out->forward(ctx, x);
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, true);
auto last_channel = ggml_ext_slice(ctx->ggml_ctx, x, 3, x->ne[3] - 1, x->ne[3]);
auto repeat_shape = ggml_new_tensor_4d(ctx->ggml_ctx, last_channel->type, last_channel->ne[0], last_channel->ne[1], last_channel->ne[2], latent_channels - 1);
auto repeated = ggml_repeat(ctx->ggml_ctx, last_channel, repeat_shape);
return ggml_concat(ctx->ggml_ctx, x, repeated, 3);
}
};
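// Decoder: conv_in -> up blocks (layout inferred from the checkpoint) ->
// optional final modulation from last_scale_shift_table plus the timestep
// embedding -> conv_out with 3 * patch_size^2 channels, later unpatchified
// back to RGB.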
struct Decoder : public GGMLBlock {
int version;
int patch_size;
bool causal_decoder;
bool timestep_conditioning;
int64_t in_channels;
int64_t hidden_channels;
protected:
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
if (timestep_conditioning) {
params["timestep_scale_multiplier"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
params["last_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_channels, 2);
}
}
public:
Decoder(int version,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int patch_size = 4,
bool causal_decoder = false,
bool timestep_conditioning = true,
int64_t in_channels = 128,
int64_t hidden_channels = 128)
: version(version),
patch_size(patch_size),
causal_decoder(causal_decoder),
timestep_conditioning(timestep_conditioning),
in_channels(in_channels),
hidden_channels(hidden_channels) {
const int64_t conv_in_out_channels = get_tensor_ne0(tensor_storage_map,
prefix + ".decoder.conv_in.conv.bias",
hidden_channels);
auto cfg = infer_decoder_config_from_weights(tensor_storage_map,
prefix,
conv_in_out_channels);
int64_t channels = conv_in_out_channels;
blocks["conv_in"] = std::make_shared<CausalConv3d>(in_channels, channels, 3);
for (int block_idx = 0; block_idx < static_cast<int>(cfg.blocks.size()); ++block_idx) {
const auto& block = cfg.blocks[block_idx];
if (block.type == "res_x") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<UNetMidBlock3D>(channels,
block.num_layers,
timestep_conditioning);
} else if (block.type == "compress_all") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<DepthToSpaceUpsample>(channels,
2,
2,
block.multiplier,
false);
channels /= block.multiplier;
} else if (block.type == "compress_time") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<DepthToSpaceUpsample>(channels,
2,
1,
block.multiplier,
false);
channels /= block.multiplier;
} else if (block.type == "compress_space") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<DepthToSpaceUpsample>(channels,
1,
2,
block.multiplier,
false);
channels /= block.multiplier;
} else {
GGML_ABORT("Unsupported LTX VAE decoder block");
}
}
hidden_channels = channels;
blocks["conv_norm_out"] = std::make_shared<PixelNorm3D>();
blocks["conv_out"] = std::make_shared<CausalConv3d>(hidden_channels, 3 * patch_size * patch_size, 3);
if (timestep_conditioning) {
blocks["last_time_embedder"] = std::make_shared<PixArtAlphaCombinedTimestepSizeEmbeddings>(hidden_channels * 2);
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
ggml_tensor* scaled_timestep = timestep;
if (timestep_conditioning) {
auto multiplier = ggml_ext_backend_tensor_get_f32(params["timestep_scale_multiplier"]);
scaled_timestep = ggml_ext_scale(ctx->ggml_ctx, timestep, multiplier);
}
x = conv_in->forward(ctx, x, causal_decoder);
int block_idx = 0;
while (blocks.find("up_blocks." + std::to_string(block_idx)) != blocks.end()) {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder);
} else {
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(blocks["up_blocks." + std::to_string(block_idx)]);
x = upsample->forward(ctx, x, causal_decoder);
}
block_idx++;
}
x = conv_norm_out->forward(ctx, x);
if (timestep_conditioning) {
auto last_time_embedder = std::dynamic_pointer_cast<PixArtAlphaCombinedTimestepSizeEmbeddings>(blocks["last_time_embedder"]);
auto timestep_embed = last_time_embedder->forward(ctx, scaled_timestep);
auto [shift, scale] = get_shift_scale(ctx->ggml_ctx,
params["last_scale_shift_table"],
timestep_embed,
hidden_channels,
2);
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
}
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, causal_decoder);
return x;
}
};
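// Full VAE graph. decode() un-normalizes latents with the per-channel
// statistics before running the decoder; encode() patchifies the input, keeps
// only the mean half of the encoder output (deterministic, no sampling) and
// normalizes it.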
struct VideoVAE : public GGMLBlock {
int version;
float decode_timestep;
bool timestep_conditioning;
int patch_size;
bool decode_only;
VideoVAE(int version,
bool decode_only,
bool timestep_conditioning,
int patch_size,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
float decode_timestep = 0.05f)
: version(version),
decode_timestep(decode_timestep),
timestep_conditioning(timestep_conditioning),
patch_size(patch_size),
decode_only(decode_only) {
if (!decode_only) {
blocks["encoder"] = std::make_shared<Encoder>(version, patch_size);
}
blocks["decoder"] = std::make_shared<Decoder>(version,
tensor_storage_map,
prefix,
patch_size,
false,
timestep_conditioning);
blocks["per_channel_statistics"] = std::make_shared<PerChannelStatistics>();
}
ggml_tensor* decode(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
auto out = decoder->forward(ctx, latents, timestep);
out = WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
return out;
}
ggml_tensor* encode(GGMLRunnerContext* ctx,
ggml_tensor* x) {
GGML_ASSERT(!decode_only);
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
x = patchify(ctx->ggml_ctx, x, patch_size);
auto out = encoder->forward(ctx, x);
auto mean = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3, false)[0];
mean = ggml_cont(ctx->ggml_ctx, mean);
return processor->normalize(ctx, mean);
}
};
} // namespace LTXVAE
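// GGMLRunner wrapper around LTXVAE::VideoVAE. VAE version and timestep
// conditioning are detected from the checkpoint tensor names; the decode
// timestep is passed into the graph as a one-element input tensor.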
struct LTXVideoVAE : public VAE {
bool decode_only;
int ltx_vae_version;
bool timestep_conditioning;
int patch_size;
sd::Tensor<float> decode_timestep_tensor;
LTXVAE::VideoVAE vae;
LTXVideoVAE(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
bool decode_only = true,
SDVersion version = VERSION_LTXAV)
: VAE(version, backend, offload_params_to_cpu),
decode_only(decode_only),
ltx_vae_version(LTXVAE::detect_ltx_vae_version(tensor_storage_map, prefix)),
timestep_conditioning(LTXVAE::detect_ltx_vae_timestep_conditioning(tensor_storage_map, prefix)),
patch_size(4),
decode_timestep_tensor(sd::Tensor<float>::from_vector({0.05f})),
vae(ltx_vae_version,
decode_only,
timestep_conditioning,
patch_size,
tensor_storage_map,
prefix) {
vae.init(params_ctx, tensor_storage_map, prefix);
decode_timestep_tensor.values()[0] = vae.decode_timestep;
}
std::string get_desc() override {
return "ltx_video_vae";
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
vae.get_param_tensors(tensors, prefix);
}
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
LOG_DEBUG("ltx_video_vae build_graph input %dx%dx%dx%d",
(int)z_tensor.shape()[0],
(int)z_tensor.shape()[1],
(int)z_tensor.shape()[2],
(int)z_tensor.shape()[3]);
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
ggml_tensor* z = make_input(z_tensor);
ggml_tensor* timestep = nullptr;
if (timestep_conditioning) {
timestep = make_input(decode_timestep_tensor);
}
auto runner_ctx = get_context();
ggml_tensor* out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
LOG_DEBUG("ltx_video_vae build_graph output ne=[%lld,%lld,%lld,%lld]",
(long long)out->ne[0],
(long long)out->ne[1],
(long long)out->ne[2],
(long long)out->ne[3]);
ggml_build_forward_expand(gf, out);
return gf;
}
sd::Tensor<float> _compute(const int n_threads,
const sd::Tensor<float>& z,
bool decode_graph) override {
if (!decode_graph && decode_only) {
LOG_ERROR("LTX video VAE encoder is not implemented yet");
return {};
}
sd::Tensor<float> input = z;
size_t expected_dim = static_cast<size_t>(z.dim());
if (!decode_graph) {
if (input.dim() == 4) {
input = input.unsqueeze(2);
expected_dim = 5;
} else if (input.dim() != 5) {
LOG_ERROR("LTX video VAE encoder expects 4D image or 5D video input, got dim=%lld",
(long long)input.dim());
return {};
}
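// LTX compresses time 8x with a causal first frame, so the encoder expects
// 8k + 1 input frames; crop down to the nearest valid frame count.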
int64_t cropped_t = std::max<int64_t>(1, 1 + ((input.shape()[2] - 1) / 8) * 8);
if (cropped_t != input.shape()[2]) {
input = sd::ops::slice(input, 2, 0, cropped_t);
}
}
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input, decode_graph);
};
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), expected_dim);
if (result.empty()) {
return {};
}
LOG_DEBUG("ltx_video_vae host output shape=[%lld,%lld,%lld,%lld] dim=%lld",
(long long)(result.shape().size() > 0 ? result.shape()[0] : 0),
(long long)(result.shape().size() > 1 ? result.shape()[1] : 0),
(long long)(result.shape().size() > 2 ? result.shape()[2] : 0),
(long long)(result.shape().size() > 3 ? result.shape()[3] : 0),
(long long)result.dim());
return result;
}
int get_encoder_output_channels(int input_channels) override {
SD_UNUSED(input_channels);
return 256;
}
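// Keep only the first 128 (mean) channels of the encoder output; encoding is
// deterministic here, so rng is unused.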
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
SD_UNUSED(rng);
if (vae_output.dim() >= 4 && vae_output.shape()[3] > 128) {
return sd::ops::slice(vae_output, 3, 0, 128);
}
return vae_output;
}
sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
return latents;
}
sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
return latents;
}
void test(const std::string& input_path) {
auto z = sd::load_tensor_from_file_as_tensor<float>(input_path);
print_sd_tensor(z, false, "ltx_vae_z");
z = diffusion_to_vae_latents(z);
int64_t t0 = ggml_time_ms();
auto out = _compute(8, z, true);
int64_t t1 = ggml_time_ms();
GGML_ASSERT(!out.empty());
print_sd_tensor(out, false, "ltx_vae_out");
LOG_DEBUG("ltx vae test done in %lldms", t1 - t0);
}
static void load_from_file_and_test(const std::string& model_path,
const std::string& input_path) {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
LOG_INFO("loading ltx vae from '%s'", model_path.c_str());
ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(model_path, "vae.")) {
LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str());
return;
}
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
std::shared_ptr<LTXVideoVAE> vae = std::make_shared<LTXVideoVAE>(backend,
false,
tensor_storage_map,
"first_stage_model",
true,
VERSION_LTXAV);
vae->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
vae->get_param_tensors(tensors, "first_stage_model");
if (!model_loader.load_tensors(tensors)) {
LOG_ERROR("load tensors from model loader failed");
return;
}
LOG_INFO("ltx vae model loaded");
vae->test(input_path);
}
};
#endif // __SD_LTX_VAE_HPP__