leejet 2026-04-27 21:43:22 +08:00
parent 274ecd5d41
commit ca7e008d78
19 changed files with 3415 additions and 141 deletions

View File

@ -20,6 +20,7 @@
#include "common/resource_owners.hpp"
#include "image_metadata.h"
#include "llm.hpp"
#include "ltx_vae_test.h"
namespace fs = std::filesystem;
@ -503,12 +504,24 @@ int main(int argc, const char* argv[]) {
cli_params.verbose = true;
sd_set_log_callback(sd_log_cb, (void*)&cli_params);
GemmaTokenizer tokenizer;
auto tokens = tokenizer.tokenize("<html> 一只可爱的小猫");
for (auto token : tokens) {
LOG_INFO("%d", token);
}
{
const bool run_ltx_vae_test = false;
const std::string model_path = "E:/Code/ComfyUI/models/vae/ltx-2.3-22b-dev_video_vae.safetensors";
const std::string input_path = "E:/Code/sd.cpp/build/ltx_vae_z.bin";
if (run_ltx_vae_test) {
ltx_vae_load_from_file_and_test(model_path, input_path);
return 0;
}
}
// cli_params.verbose = true;
// sd_set_log_callback(sd_log_cb, (void*)&cli_params);
// GemmaTokenizer tokenizer;
// auto tokens = tokenizer.tokenize("<html> 一只可爱的小猫");
// for (auto token : tokens) {
// LOG_INFO("%d", token);
// }
// return 0;
parse_args(argc, argv, cli_params, ctx_params, gen_params);
sd_set_log_callback(sd_log_cb, (void*)&cli_params);

View File

@ -340,6 +340,10 @@ ArgOptions SDContextParams::get_options() {
"--high-noise-diffusion-model",
"path to the standalone high noise diffusion model",
&high_noise_diffusion_model_path},
{"",
"--embeddings-connectors",
"path to LTXAV embeddings connectors",
&embeddings_connectors_path},
{"",
"--vae",
"path to standalone vae model",
@ -656,6 +660,7 @@ std::string SDContextParams::to_string() const {
<< " llm_vision_path: \"" << llm_vision_path << "\",\n"
<< " diffusion_model_path: \"" << diffusion_model_path << "\",\n"
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
<< " vae_path: \"" << vae_path << "\",\n"
<< " taesd_path: \"" << taesd_path << "\",\n"
<< " esrgan_path: \"" << esrgan_path << "\",\n"
@ -712,6 +717,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
llm_vision_path.c_str(),
diffusion_model_path.c_str(),
high_noise_diffusion_model_path.c_str(),
embeddings_connectors_path.c_str(),
vae_path.c_str(),
taesd_path.c_str(),
control_net_path.c_str(),
@ -2180,6 +2186,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() {
params.strength = strength;
params.seed = seed;
params.video_frames = video_frames;
params.fps = fps;
params.vace_strength = vace_strength;
params.vae_tiling_params = vae_tiling_params;
params.cache = cache_params;

View File

@ -92,6 +92,7 @@ struct SDContextParams {
std::string llm_vision_path;
std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string embeddings_connectors_path;
std::string vae_path;
std::string taesd_path;
std::string esrgan_path;

View File

@ -171,6 +171,7 @@ typedef struct {
const char* llm_vision_path;
const char* diffusion_model_path;
const char* high_noise_diffusion_model_path;
const char* embeddings_connectors_path;
const char* vae_path;
const char* taesd_path;
const char* control_net_path;
@ -359,6 +360,7 @@ typedef struct {
float strength;
int64_t seed;
int video_frames;
int fps;
float vace_strength;
sd_tiling_params_t vae_tiling_params;
sd_cache_params_t cache;

View File

@ -1,6 +1,8 @@
#ifndef __CONDITIONER_HPP__
#define __CONDITIONER_HPP__
#include <cmath>
#include <limits>
#include <optional>
#include "clip.hpp"
@ -46,6 +48,17 @@ static inline sd::Tensor<float> apply_token_weights(sd::Tensor<float> hidden_sta
return hidden_states;
}
bool all_one = true;
for (float weight : weights) {
if (weight != 1.0f) {
all_one = false;
break;
}
}
if (all_one) {
return hidden_states;
}
if (hidden_states.dim() == 1) {
hidden_states.unsqueeze_(1);
}
@ -57,7 +70,7 @@ static inline sd::Tensor<float> apply_token_weights(sd::Tensor<float> hidden_sta
chunk_weights.reshape_({1, static_cast<int64_t>(weights.size())});
hidden_states *= chunk_weights;
float new_mean = hidden_states.mean();
if (new_mean != 0.0f) {
if (std::isfinite(original_mean) && std::isfinite(new_mean) && new_mean != 0.0f) {
hidden_states *= (original_mean / new_mean);
}
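A minimal standalone sketch of this weighting scheme (hypothetical helper, not part of the commit): scale each token's hidden value by its weight, then rescale everything so the mean matches the pre-weighting mean; the isfinite and non-zero guards added above keep the rescale from propagating NaN or dividing by zero.
#include <cmath>
#include <vector>
// Scalar illustration of apply_token_weights; assumes h and w have the same
// non-zero length.
static std::vector<float> weight_and_rescale(std::vector<float> h,
                                             const std::vector<float>& w) {
    float original_mean = 0.0f;
    for (float v : h) original_mean += v;
    original_mean /= static_cast<float>(h.size());
    for (size_t i = 0; i < h.size(); ++i) h[i] *= w[i];  // apply per-token weights
    float new_mean = 0.0f;
    for (float v : h) new_mean += v;
    new_mean /= static_cast<float>(h.size());
    // Restore the original mean so weighted prompts keep a stable magnitude.
    if (std::isfinite(original_mean) && std::isfinite(new_mean) && new_mean != 0.0f) {
        for (float& v : h) v *= original_mean / new_mean;
    }
    return h;
}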
@ -1958,4 +1971,277 @@ struct LLMEmbedder : public Conditioner {
}
};
struct LTXAVTextProjection : public GGMLBlock {
static constexpr int64_t kHiddenSize = 3840;
static constexpr int64_t kNumStates = 49;
bool dual_projection = false;
LTXAVTextProjection(bool dual_projection = false)
: dual_projection(dual_projection) {
if (dual_projection) {
blocks["video_aggregate_embed"] = std::make_shared<Linear>(kHiddenSize * kNumStates, 4096, true);
blocks["audio_aggregate_embed"] = std::make_shared<Linear>(kHiddenSize * kNumStates, 2048, true);
} else {
blocks["projection"] = std::make_shared<Linear>(kHiddenSize * kNumStates, kHiddenSize, false);
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
if (!dual_projection) {
auto projection = std::dynamic_pointer_cast<Linear>(blocks["projection"]);
return projection->forward(ctx, x);
}
auto video_projection = std::dynamic_pointer_cast<Linear>(blocks["video_aggregate_embed"]);
auto audio_projection = std::dynamic_pointer_cast<Linear>(blocks["audio_aggregate_embed"]);
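// Pre-scale the shared input by sqrt(out_dim / hidden_dim), presumably to match
// each head's scale to its target width: ~1.033 for the 4096-dim video head,
// ~0.73 for the 2048-dim audio head.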
auto video_in = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(4096.f / static_cast<float>(kHiddenSize)));
auto audio_in = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(2048.f / static_cast<float>(kHiddenSize)));
auto video = video_projection->forward(ctx, video_in);
auto audio = audio_projection->forward(ctx, audio_in);
return ggml_concat(ctx->ggml_ctx, video, audio, 0);
}
};
struct LTXAVTextProjectionRunner : public GGMLRunner {
LTXAVTextProjection model;
LTXAVTextProjectionRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "")
: GGMLRunner(backend, offload_params_to_cpu),
model(tensor_storage_map.find(prefix + ".video_aggregate_embed.weight") != tensor_storage_map.end()) {
model.init(params_ctx, tensor_storage_map, prefix);
}
std::string get_desc() override {
return "ltxav_text_projection";
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) {
model.get_param_tensors(tensors, prefix);
}
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
auto x = make_input(x_tensor);
auto runner_ctx = get_context();
auto out = model.forward(&runner_ctx, x);
ggml_build_forward_expand(gf, out);
return gf;
}
sd::Tensor<float> compute(int n_threads, const sd::Tensor<float>& x) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x);
};
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
}
};
struct LTXAVEmbedder : public Conditioner {
static constexpr int64_t kHiddenSize = 3840;
static constexpr int64_t kNumStates = 49;
static constexpr int64_t kMinLength = 1024;
std::shared_ptr<GemmaTokenizer> tokenizer;
std::shared_ptr<LLM::LLMRunner> llm;
std::shared_ptr<LTXAVTextProjectionRunner> projector;
bool dual_projection = false;
LTXAVEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string& llm_prefix = "text_encoders.llm",
const std::string& projector_prefix = "text_embedding_projection") {
tokenizer = std::make_shared<GemmaTokenizer>();
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::GEMMA3_12B,
backend,
offload_params_to_cpu,
tensor_storage_map,
llm_prefix,
false);
dual_projection = tensor_storage_map.find(projector_prefix + ".video_aggregate_embed.weight") != tensor_storage_map.end();
projector = std::make_shared<LTXAVTextProjectionRunner>(backend,
offload_params_to_cpu,
tensor_storage_map,
projector_prefix);
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
llm->get_param_tensors(tensors, "text_encoders.llm");
projector->get_param_tensors(tensors, "text_embedding_projection");
}
void alloc_params_buffer() override {
llm->alloc_params_buffer();
projector->alloc_params_buffer();
}
void free_params_buffer() override {
llm->free_params_buffer();
projector->free_params_buffer();
}
size_t get_params_buffer_size() override {
return llm->get_params_buffer_size() + projector->get_params_buffer_size();
}
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
projector->set_flash_attention_enabled(enabled);
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
llm->set_weight_adapter(adapter);
projector->set_weight_adapter(adapter);
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
const std::pair<int, int>& attn_range) {
std::vector<std::pair<std::string, float>> parsed_attention;
if (attn_range.first >= 0 && attn_range.second > 0) {
if (attn_range.first > 0) {
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
}
if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
parsed_attention.insert(parsed_attention.end(), new_parsed_attention.begin(), new_parsed_attention.end());
}
if (static_cast<size_t>(attn_range.second) < text.size()) {
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
}
} else {
parsed_attention.emplace_back(text, 1.f);
}
std::vector<int> tokens;
std::vector<float> weights;
for (const auto& item : parsed_attention) {
auto curr_tokens = tokenizer->encode(item.first, nullptr);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), item.second);
}
std::vector<float> mask;
tokenizer->pad_tokens(tokens, &weights, &mask, kMinLength);
return {tokens, weights, mask};
}
sd::Tensor<float> encode_prompt(int n_threads,
const std::string& prompt,
const std::pair<int, int>& prompt_attn_range) {
auto tokens_weights_mask = tokenize(prompt, prompt_attn_range);
auto& tokens = std::get<0>(tokens_weights_mask);
auto& weights = std::get<1>(tokens_weights_mask);
auto& mask = std::get<2>(tokens_weights_mask);
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, std::vector<int32_t>(tokens.begin(), tokens.end()));
sd::Tensor<float> attention_mask;
if (!mask.empty()) {
const float mask_min = std::numeric_limits<float>::lowest() / 4.0f;
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
for (size_t i1 = 0; i1 < mask.size(); ++i1) {
for (size_t i0 = 0; i0 < mask.size(); ++i0) {
float value = 0.0f;
if (mask[i0] == 0.0f) {
value += mask_min;
}
if (i0 > i1) {
value += mask_min;
}
attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value;
}
}
}
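// The mask built above is additive: 0 keeps a key visible, while mask_min
// (float lowest / 4, so two stacked penalties cannot overflow) effectively
// removes it from the softmax. A key is penalized when it is padding
// (mask[i0] == 0) or lies in the causal future of the query (i0 > i1).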
auto hidden_states = llm->compute(n_threads,
input_ids,
attention_mask,
{},
{},
true);
GGML_ASSERT(!hidden_states.empty());
hidden_states = apply_token_weights(std::move(hidden_states), weights);
int64_t valid_tokens = 0;
for (float value : mask) {
valid_tokens += static_cast<int64_t>(value > 0.0f);
}
GGML_ASSERT(valid_tokens > 0);
hidden_states = sd::ops::slice(hidden_states,
1,
hidden_states.shape()[1] - valid_tokens,
hidden_states.shape()[1]);
hidden_states.reshape_({kHiddenSize, kNumStates, valid_tokens});
hidden_states = hidden_states.permute({1, 0, 2});
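// The LLM ran with return_all_hidden_states, so the hidden axis holds all
// kNumStates = 49 aggregated states (presumably the input embedding, the
// intermediate layer outputs, and the final normed output); the reshape and
// permute above split them back out per token.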
if (dual_projection) {
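// Dual-projection path: RMS-normalize each (state, token) vector over the
// hidden dimension before the joint video/audio projection.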
for (int64_t state_idx = 0; state_idx < kNumStates; ++state_idx) {
for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) {
double sq_sum = 0.0;
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
float value = hidden_states.index(state_idx, hidden_idx, token_idx);
sq_sum += static_cast<double>(value) * static_cast<double>(value);
}
float inv_rms = 1.0f / std::sqrt(static_cast<float>(sq_sum / static_cast<double>(kHiddenSize)) + 1e-6f);
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
hidden_states.index(state_idx, hidden_idx, token_idx) *= inv_rms;
}
}
}
} else {
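// Single-projection path: per state, center every value on the state mean
// and rescale so the min-max range spans roughly 8 units.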
for (int64_t state_idx = 0; state_idx < kNumStates; ++state_idx) {
double sum = 0.0;
float min_value = std::numeric_limits<float>::infinity();
float max_value = -std::numeric_limits<float>::infinity();
for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) {
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
float value = hidden_states.index(state_idx, hidden_idx, token_idx);
sum += value;
min_value = std::min(min_value, value);
max_value = std::max(max_value, value);
}
}
float mean_value = static_cast<float>(sum / static_cast<double>(kHiddenSize * valid_tokens));
float denom = max_value - min_value + 1e-6f;
float scale_value = 8.0f / denom;
for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) {
for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) {
float value = hidden_states.index(state_idx, hidden_idx, token_idx);
hidden_states.index(state_idx, hidden_idx, token_idx) = (value - mean_value) * scale_value;
}
}
}
}
hidden_states.reshape_({kNumStates * kHiddenSize, valid_tokens});
return projector->compute(n_threads, hidden_states);
}
SDCondition get_learned_condition(int n_threads,
const ConditionerParams& conditioner_params) override {
int64_t t0 = ggml_time_ms();
std::string prompt;
std::pair<int, int> prompt_attn_range;
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
auto hidden_states = encode_prompt(n_threads, prompt, prompt_attn_range);
GGML_ASSERT(!hidden_states.empty());
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing LTXAV condition graph completed, taking %" PRId64 " ms", t1 - t0);
SDCondition result;
result.c_crossattn = std::move(hidden_states);
return result;
}
};
#endif

View File

@ -5,6 +5,7 @@
#include "anima.hpp"
#include "ernie_image.hpp"
#include "flux.hpp"
#include "ltxv.hpp"
#include "mmdit.hpp"
#include "qwen_image.hpp"
#include "tensor_ggml.hpp"
@ -14,7 +15,9 @@
struct DiffusionParams {
const sd::Tensor<float>* x = nullptr;
const sd::Tensor<float>* audio_x = nullptr;
const sd::Tensor<float>* timesteps = nullptr;
const sd::Tensor<float>* audio_timesteps = nullptr;
const sd::Tensor<float>* context = nullptr;
const sd::Tensor<float>* c_concat = nullptr;
const sd::Tensor<float>* y = nullptr;
@ -28,6 +31,7 @@ struct DiffusionParams {
float control_strength = 0.f;
const sd::Tensor<float>* vace_context = nullptr;
float vace_strength = 1.f;
int audio_length = 0;
const std::vector<int>* skip_layers = nullptr;
};
@ -579,4 +583,69 @@ struct ErnieImageModel : public DiffusionModel {
}
};
struct LTXAVModel : public DiffusionModel {
std::string prefix;
LTXV::LTXAVRunner ltxav;
LTXAVModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model")
: prefix(prefix), ltxav(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
}
std::string get_desc() override {
return ltxav.get_desc();
}
void alloc_params_buffer() override {
ltxav.alloc_params_buffer();
}
void free_params_buffer() override {
ltxav.free_params_buffer();
}
void free_compute_buffer() override {
ltxav.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
ltxav.get_param_tensors(tensors, prefix);
}
size_t get_params_buffer_size() override {
return ltxav.get_params_buffer_size();
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
ltxav.set_weight_adapter(adapter);
}
int64_t get_adm_in_channels() override {
return 0;
}
void set_flash_attention_enabled(bool enabled) override {
ltxav.set_flash_attention_enabled(enabled);
}
void set_circular_axes(bool circular_x, bool circular_y) override {
ltxav.set_circular_axes(circular_x, circular_y);
}
sd::Tensor<float> compute(int n_threads,
const DiffusionParams& diffusion_params) override {
GGML_ASSERT(diffusion_params.x != nullptr);
GGML_ASSERT(diffusion_params.timesteps != nullptr);
return ltxav.compute(n_threads,
*diffusion_params.x,
*diffusion_params.timesteps,
tensor_or_empty(diffusion_params.context),
tensor_or_empty(diffusion_params.audio_x),
tensor_or_empty(diffusion_params.audio_timesteps),
diffusion_params.audio_length);
}
};
#endif

View File

@ -2,8 +2,10 @@
#define __LLM_HPP__
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <optional>
@ -30,6 +32,7 @@ namespace LLM {
QWEN3,
MISTRAL_SMALL_3_2,
MINISTRAL_3_3B,
GEMMA3_12B,
ARCH_COUNT,
};
@ -38,6 +41,12 @@ namespace LLM {
"qwen3",
"mistral_small3.2",
"ministral3.3b",
"gemma3_12b",
};
enum class MLPActivation {
SILU,
GELU_TANH,
};
struct LLMVisionParams {
@ -64,14 +73,62 @@ namespace LLM {
int head_dim = 128;
bool qkv_bias = true;
bool qk_norm = false;
bool rms_norm_add = false;
bool normalize_input = false;
int64_t vocab_size = 152064;
int64_t max_position_embeddings = 128000;
float rms_norm_eps = 1e-06f;
MLPActivation mlp_activation = MLPActivation::SILU;
std::vector<float> rope_thetas = {1000000.f};
std::vector<float> rope_scales = {1.f};
std::vector<int> sliding_attention;
LLMVisionParams vision;
};
struct MLP : public GGMLBlock {
struct LLMRMSNorm : public UnaryBlock {
protected:
int64_t hidden_size;
float eps;
bool add_unit_offset;
std::string prefix;
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
std::string prefix = "") override {
this->prefix = prefix;
params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size);
}
public:
MLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) {
LLMRMSNorm(int64_t hidden_size,
float eps = 1e-06f,
bool add_unit_offset = false)
: hidden_size(hidden_size), eps(eps), add_unit_offset(add_unit_offset) {}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
ggml_tensor* w = params["weight"];
if (ctx->weight_adapter) {
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
}
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
auto scaled = ggml_mul(ctx->ggml_ctx, x, w);
if (add_unit_offset) {
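// Gemma-style unit offset: the effective weight is (1 + w), computed as
// x * w + x without materializing the offset weight.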
scaled = ggml_add_inplace(ctx->ggml_ctx, scaled, x);
}
return scaled;
}
};
struct MLP : public GGMLBlock {
protected:
MLPActivation activation;
public:
MLP(int64_t hidden_size,
int64_t intermediate_size,
bool bias = false,
MLPActivation activation_ = MLPActivation::SILU)
: activation(activation_) {
blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
blocks["up_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, bias));
blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, bias));
@ -84,7 +141,11 @@ namespace LLM {
auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
auto h = gate_proj->forward(ctx, x);
if (activation == MLPActivation::GELU_TANH) {
h = ggml_ext_gelu(ctx->ggml_ctx, h, true);
} else {
h = ggml_silu_inplace(ctx->ggml_ctx, h);
}
h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x));
h = down_proj->forward(ctx, h);
return h;
@ -377,24 +438,35 @@ namespace LLM {
int64_t num_heads;
int64_t num_kv_heads;
bool qk_norm;
int64_t max_position_embeddings;
std::vector<float> rope_thetas;
std::vector<float> rope_scales;
public:
Attention(const LLMParams& params)
: arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) {
: arch(params.arch),
num_heads(params.num_heads),
num_kv_heads(params.num_kv_heads),
head_dim(params.head_dim),
qk_norm(params.qk_norm),
max_position_embeddings(params.max_position_embeddings),
rope_thetas(params.rope_thetas),
rope_scales(params.rope_scales) {
blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias);
blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, false);
if (params.qk_norm) {
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
blocks["q_norm"] = std::make_shared<LLMRMSNorm>(head_dim, params.rms_norm_eps, params.rms_norm_add);
blocks["k_norm"] = std::make_shared<LLMRMSNorm>(head_dim, params.rms_norm_eps, params.rms_norm_add);
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* input_pos,
ggml_tensor* attention_mask = nullptr) {
ggml_tensor* attention_mask = nullptr,
int rope_index = 0) {
// x: [N, n_token, hidden_size]
int64_t n_token = x->ne[1];
int64_t N = x->ne[2];
@ -412,8 +484,8 @@ namespace LLM {
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
if (qk_norm) {
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
auto q_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["q_norm"]);
auto k_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["k_norm"]);
q = q_norm->forward(ctx, q);
k = k_norm->forward(ctx, k);
@ -428,6 +500,36 @@ namespace LLM {
} else if (arch == LLMArch::QWEN3) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::GEMMA3_12B) {
float rope_theta = (rope_index == 1 ? 10000.0f : 1000000.0f);
float rope_scale = (rope_index == 1 ? 1.f : 8.f);
float freq_scale = 1.f / rope_scale;
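// Gemma3 uses two RoPE configurations: sliding-window layers (rope_index == 1)
// keep the local theta of 10000 with no position scaling, while global layers
// use theta 1e6 with 8x linear scaling (freq_scale = 1/8).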
q = ggml_rope_ext(ctx->ggml_ctx,
q,
input_pos,
nullptr,
head_dim,
GGML_ROPE_TYPE_NORMAL,
0,
rope_theta,
freq_scale,
0.f,
1.f,
32.f,
1.f);
k = ggml_rope_ext(ctx->ggml_ctx,
k,
input_pos,
nullptr,
head_dim,
GGML_ROPE_TYPE_NORMAL,
0,
rope_theta,
freq_scale,
0.f,
1.f,
32.f,
1.f);
} else {
int sections[4] = {16, 24, 24, 0};
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@ -448,32 +550,75 @@ namespace LLM {
};
struct TransformerBlock : public GGMLBlock {
protected:
LLMArch arch;
int sliding_attention;
bool has_post_attention_norm;
bool has_post_ffw_norm;
public:
TransformerBlock(const LLMParams& params) {
TransformerBlock(const LLMParams& params, int layer_index)
: arch(params.arch),
sliding_attention(0),
has_post_attention_norm(params.arch == LLMArch::GEMMA3_12B),
has_post_ffw_norm(params.arch == LLMArch::GEMMA3_12B) {
blocks["self_attn"] = std::make_shared<Attention>(params);
blocks["mlp"] = std::make_shared<MLP>(params.hidden_size, params.intermediate_size);
blocks["input_layernorm"] = std::make_shared<RMSNorm>(params.hidden_size, params.rms_norm_eps);
blocks["post_attention_layernorm"] = std::make_shared<RMSNorm>(params.hidden_size, params.rms_norm_eps);
blocks["mlp"] = std::make_shared<MLP>(params.hidden_size,
params.intermediate_size,
false,
params.mlp_activation);
blocks["input_layernorm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
blocks["post_attention_layernorm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
if (has_post_attention_norm) {
blocks["post_attention_norm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
}
if (has_post_ffw_norm) {
blocks["post_ffw_norm"] = std::make_shared<LLMRMSNorm>(params.hidden_size, params.rms_norm_eps, params.rms_norm_add);
}
if (!params.sliding_attention.empty()) {
sliding_attention = params.sliding_attention[layer_index % params.sliding_attention.size()];
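// The pattern repeats across layers: with Gemma3's {1024, 1024, 1024, 1024,
// 1024, 0}, every sixth layer uses global attention and the rest use a
// 1024-token sliding window.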
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* input_pos,
ggml_tensor* attention_mask = nullptr) {
ggml_tensor* attention_mask = nullptr,
ggml_tensor* sliding_attention_mask = nullptr) {
// x: [N, n_token, hidden_size]
auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
auto input_layernorm = std::dynamic_pointer_cast<RMSNorm>(blocks["input_layernorm"]);
auto post_attention_layernorm = std::dynamic_pointer_cast<RMSNorm>(blocks["post_attention_layernorm"]);
auto input_layernorm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["input_layernorm"]);
auto post_attention_layernorm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["post_attention_layernorm"]);
std::shared_ptr<LLMRMSNorm> post_attention_norm = nullptr;
std::shared_ptr<LLMRMSNorm> post_ffw_norm = nullptr;
if (has_post_attention_norm) {
post_attention_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["post_attention_norm"]);
}
if (has_post_ffw_norm) {
post_ffw_norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["post_ffw_norm"]);
}
ggml_tensor* block_attention_mask = attention_mask;
int rope_index = 0;
if (arch == LLMArch::GEMMA3_12B && sliding_attention > 0) {
block_attention_mask = sliding_attention_mask;
rope_index = 1;
}
auto residual = x;
x = input_layernorm->forward(ctx, x);
x = self_attn->forward(ctx, x, input_pos, attention_mask);
x = self_attn->forward(ctx, x, input_pos, block_attention_mask, rope_index);
if (post_attention_norm != nullptr) {
x = post_attention_norm->forward(ctx, x);
}
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
residual = x;
x = post_attention_layernorm->forward(ctx, x);
x = mlp->forward(ctx, x);
if (post_ffw_norm != nullptr) {
x = post_ffw_norm->forward(ctx, x);
}
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
return x;
@ -483,28 +628,36 @@ namespace LLM {
struct TextModel : public GGMLBlock {
protected:
int64_t num_layers;
int64_t hidden_size;
bool normalize_input;
float input_scale;
public:
TextModel(const LLMParams& params)
: num_layers(params.num_layers) {
: num_layers(params.num_layers),
hidden_size(params.hidden_size),
normalize_input(params.normalize_input),
input_scale(std::sqrt(static_cast<float>(params.hidden_size))) {
blocks["embed_tokens"] = std::shared_ptr<GGMLBlock>(new Embedding(params.vocab_size, params.hidden_size));
for (int i = 0; i < num_layers; i++) {
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new TransformerBlock(params));
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new TransformerBlock(params, i));
}
blocks["norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(params.hidden_size, params.rms_norm_eps));
blocks["norm"] = std::shared_ptr<GGMLBlock>(new LLMRMSNorm(params.hidden_size, params.rms_norm_eps, params.rms_norm_add));
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* input_ids,
ggml_tensor* input_pos,
ggml_tensor* attention_mask,
ggml_tensor* sliding_attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
std::set<int> out_layers,
bool return_all_hidden_states = false) {
// input_ids: [N, n_token]
// return: [N, n_token, hidden_size]
auto embed_tokens = std::dynamic_pointer_cast<Embedding>(blocks["embed_tokens"]);
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
auto norm = std::dynamic_pointer_cast<LLMRMSNorm>(blocks["norm"]);
auto x = embed_tokens->forward(ctx, input_ids);
@ -549,22 +702,44 @@ namespace LLM {
x = input_embed;
}
if (normalize_input) {
x = ggml_ext_scale(ctx->ggml_ctx, x, input_scale, true);
}
if (return_all_hidden_states) {
intermediate_outputs.push_back(x);
}
for (int i = 0; i < num_layers; i++) {
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
x = block->forward(ctx, x, input_pos, attention_mask);
if (out_layers.find(i + 1) != out_layers.end()) {
x = block->forward(ctx, x, input_pos, attention_mask, sliding_attention_mask);
if (return_all_hidden_states) {
if (i + 1 < num_layers) {
intermediate_outputs.push_back(x);
}
} else if (out_layers.find(i + 1) != out_layers.end()) {
intermediate_outputs.push_back(x);
}
}
if (!intermediate_outputs.empty()) {
auto normed_x = norm->forward(ctx, x);
if (return_all_hidden_states) {
intermediate_outputs.push_back(normed_x);
x = intermediate_outputs[0];
for (int i = 1; i < intermediate_outputs.size(); i++) {
x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0);
}
} else if (!intermediate_outputs.empty()) {
if (out_layers.find(static_cast<int>(num_layers + 1)) != out_layers.end()) {
intermediate_outputs.push_back(normed_x);
}
x = intermediate_outputs[0];
for (int i = 1; i < intermediate_outputs.size(); i++) {
x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0);
}
} else {
x = norm->forward(ctx, x);
x = normed_x;
}
return x;
}
@ -599,12 +774,21 @@ namespace LLM {
ggml_tensor* input_ids,
ggml_tensor* input_pos,
ggml_tensor* attention_mask,
ggml_tensor* sliding_attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
std::set<int> out_layers,
bool return_all_hidden_states = false) {
// input_ids: [N, n_token]
auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
auto x = model->forward(ctx,
input_ids,
input_pos,
attention_mask,
sliding_attention_mask,
image_embeds,
out_layers,
return_all_hidden_states);
return x;
}
@ -627,6 +811,7 @@ namespace LLM {
std::vector<int> input_pos_vec;
std::vector<float> attention_mask_vec;
std::vector<float> sliding_attention_mask_vec;
std::vector<float> window_mask_vec;
std::vector<int> window_index_vec;
std::vector<int> window_inverse_index_vec;
@ -653,6 +838,23 @@ namespace LLM {
params.qkv_bias = false;
params.qk_norm = true;
params.rms_norm_eps = 1e-6f;
} else if (arch == LLMArch::GEMMA3_12B) {
params.head_dim = 256;
params.num_heads = 16;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.rms_norm_eps = 1e-6f;
// llama.cpp pre-adds the Gemma3 unit offset (+1) to norm.weight when
// exporting GGUF, so rms_norm_add must stay disabled here or the offset
// would be applied twice.
params.rms_norm_add = false;
params.normalize_input = true;
params.max_position_embeddings = 131072;
params.mlp_activation = MLPActivation::GELU_TANH;
params.rope_thetas = {1000000.f, 10000.f};
params.rope_scales = {8.f, 1.f};
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
}
bool have_vision_weight = false;
bool llama_cpp_style = false;
@ -722,9 +924,18 @@ namespace LLM {
ggml_tensor* input_ids,
ggml_tensor* input_pos,
ggml_tensor* attention_mask,
ggml_tensor* sliding_attention_mask,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
std::set<int> out_layers) {
auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size]
std::set<int> out_layers,
bool return_all_hidden_states = false) {
auto hidden_states = model.forward(ctx,
input_ids,
input_pos,
attention_mask,
sliding_attention_mask,
image_embeds,
out_layers,
return_all_hidden_states); // [N, n_token, hidden_size]
return hidden_states;
}
@ -741,8 +952,9 @@ namespace LLM {
ggml_cgraph* build_graph(const sd::Tensor<int32_t>& input_ids_tensor,
const sd::Tensor<float>& attention_mask_tensor,
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds_tensor,
std::set<int> out_layers) {
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
std::set<int> out_layers,
bool return_all_hidden_states = false) {
ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE);
ggml_tensor* input_ids = make_input(input_ids_tensor);
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
image_embeds.reserve(image_embeds_tensor.size());
@ -752,7 +964,10 @@ namespace LLM {
}
int64_t n_tokens = input_ids->ne[0];
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) {
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 ||
params.arch == LLMArch::MINISTRAL_3_3B ||
params.arch == LLMArch::QWEN3 ||
params.arch == LLMArch::GEMMA3_12B) {
input_pos_vec.resize(n_tokens);
for (int i = 0; i < n_tokens; ++i) {
input_pos_vec[i] = i;
@ -773,6 +988,7 @@ namespace LLM {
set_backend_tensor_data(input_pos, input_pos_vec.data());
ggml_tensor* attention_mask = nullptr;
ggml_tensor* sliding_attention_mask = nullptr;
if (!attention_mask_tensor.empty()) {
attention_mask = make_input(attention_mask_tensor);
} else {
@ -790,9 +1006,36 @@ namespace LLM {
set_backend_tensor_data(attention_mask, attention_mask_vec.data());
}
if (params.arch == LLMArch::GEMMA3_12B) {
sliding_attention_mask_vec.resize(n_tokens * n_tokens);
if (!attention_mask_tensor.empty()) {
GGML_ASSERT(attention_mask_tensor.numel() == n_tokens * n_tokens);
sliding_attention_mask_vec = attention_mask_tensor.values();
} else {
sliding_attention_mask_vec = attention_mask_vec;
}
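// Start from the dense causal/pad mask, then additionally hide keys more than
// 1023 positions behind the query, i.e. outside the 1024-token sliding window.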
for (int i0 = 0; i0 < n_tokens; i0++) {
for (int i1 = 0; i1 < n_tokens; i1++) {
if (i0 + 1024 <= i1) {
sliding_attention_mask_vec[i1 * n_tokens + i0] = -INFINITY;
}
}
}
sliding_attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
set_backend_tensor_data(sliding_attention_mask, sliding_attention_mask_vec.data());
}
auto runner_ctx = get_context();
ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
ggml_tensor* hidden_states = forward(&runner_ctx,
input_ids,
input_pos,
attention_mask,
sliding_attention_mask,
image_embeds,
out_layers,
return_all_hidden_states);
ggml_build_forward_expand(gf, hidden_states);
@ -803,9 +1046,14 @@ namespace LLM {
const sd::Tensor<int32_t>& input_ids,
const sd::Tensor<float>& attention_mask,
const std::vector<std::pair<int, sd::Tensor<float>>>& image_embeds,
std::set<int> out_layers) {
std::set<int> out_layers,
bool return_all_hidden_states = false) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input_ids, attention_mask, image_embeds, out_layers);
return build_graph(input_ids,
attention_mask,
image_embeds,
out_layers,
return_all_hidden_states);
};
return take_or_empty(GGMLRunner::compute<float>(get_graph, n_threads, true));
}

src/ltx_vae.h (new file, 970 additions)
View File

@ -0,0 +1,970 @@
#ifndef __SD_LTX_VAE_H__
#define __SD_LTX_VAE_H__
#include <fstream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ltxv.hpp"
#include "vae.hpp"
#include "wan.hpp"
namespace LTXVAE {
static inline ggml_tensor* apply_scale_shift(ggml_context* ctx,
ggml_tensor* x,
ggml_tensor* scale,
ggml_tensor* shift) {
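// Equivalent to x * (1 + scale) + shift.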
x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
x = ggml_add(ctx, x, shift);
return x;
}
static inline ggml_tensor* reshape_channel_broadcast(ggml_context* ctx,
ggml_tensor* x) {
return ggml_reshape_4d(ctx, x, 1, 1, 1, ggml_nelements(x));
}
static inline std::pair<ggml_tensor*, ggml_tensor*> get_shift_scale(ggml_context* ctx,
ggml_tensor* table,
ggml_tensor* timestep,
int64_t channels,
int parts) {
GGML_ASSERT(timestep != nullptr);
GGML_ASSERT(ggml_nelements(timestep) == channels * parts);
auto timestep_view = ggml_reshape_2d(ctx, timestep, channels, parts);
auto values = ggml_add(ctx, table, timestep_view);
auto chunks = ggml_ext_chunk(ctx, values, parts, 1, false);
auto shift = reshape_channel_broadcast(ctx, ggml_cont(ctx, chunks[0]));
auto scale = reshape_channel_broadcast(ctx, ggml_cont(ctx, chunks[1]));
return {shift, scale};
}
static inline ggml_tensor* depth_to_space_3d(ggml_context* ctx,
ggml_tensor* x,
int64_t c,
int factor_t,
int factor_s,
bool drop_first_temporal_frame) {
// x: [B*c*p1*p2*p3, T, H, W], B == 1, p2 == p3 == factor_s, p1 == factor_t
// return: [B*c, T*p1, H*p2, W*p3]
// Match: rearrange(x, "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)")
const int64_t T = x->ne[2];
const int64_t H = x->ne[1];
const int64_t W = x->ne[0];
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // [T, C, H, W]
x = ggml_reshape_4d(ctx, x, W, H, factor_s, factor_s * factor_t * c * T); // [T*c*p1*p2, p3, H, W]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [T*c*p1*p2, H, W, p3]
x = ggml_reshape_4d(ctx, x, factor_s * W, H, factor_s, factor_t * c * T); // [T*c*p1, p2, H, W*p3]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [T*c*p1, H, p2, W*p3]
x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, factor_t, c, T); // [T, c, p1, H*p2*W*p3]
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // [c, T, p1, H*p2*W*p3]
x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, c); // [c, T*p1, H*p2, W*p3]
if (drop_first_temporal_frame && factor_t > 1 && x->ne[2] > 0) {
x = ggml_ext_slice(ctx, x, 2, 1, x->ne[2]);
}
return x;
}
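// depth_to_space_3d example: with c = 3, factor_t = 2, factor_s = 2, a
// [B, 24, T, H, W] input becomes [B, 3, 2T, 2H, 2W] (and loses its first
// temporal frame when drop_first_temporal_frame is set).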
static inline ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x,
int patch_size) {
return WAN::WanVAE::patchify(ctx, x, patch_size, 1);
}
class CausalConv3d : public GGMLBlock {
protected:
int time_kernel_size;
public:
CausalConv3d(int64_t in_channels,
int64_t out_channels,
int kernel_size = 3,
std::tuple<int, int, int> stride = {1, 1, 1},
int dilation = 1,
bool bias = true) {
time_kernel_size = kernel_size;
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(in_channels,
out_channels,
{kernel_size, kernel_size, kernel_size},
stride,
{0, kernel_size / 2, kernel_size / 2},
{dilation, 1, 1},
bias));
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal = true) {
// x: [B*C, T, H, W], B == 1
auto conv = std::dynamic_pointer_cast<Conv3d>(blocks["conv"]);
if (causal) {
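// Causal temporal padding: replicate the first frame (time_kernel_size - 1)
// times in front so the convolution never sees future frames; the non-causal
// branch below pads both ends symmetrically instead.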
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame;
for (int i = 1; i < time_kernel_size - 1; i++) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
}
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
} else {
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame;
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
}
auto last_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - 1, x->ne[2]);
auto last_frame_pad = last_frame;
for (int i = 1; i < (time_kernel_size - 1) / 2; i++) {
last_frame_pad = ggml_concat(ctx->ggml_ctx, last_frame_pad, last_frame, 2);
}
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
x = ggml_concat(ctx->ggml_ctx, x, last_frame_pad, 2);
}
return conv->forward(ctx, x);
}
};
struct PixelNorm3D : public UnaryBlock {
float eps;
PixelNorm3D(float eps = 1e-8f)
: eps(eps) {}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2));
h = ggml_rms_norm(ctx->ggml_ctx, h, eps);
h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0));
return h;
}
};
struct PixArtAlphaCombinedTimestepSizeEmbeddings : public GGMLBlock {
int64_t embedding_dim;
PixArtAlphaCombinedTimestepSizeEmbeddings(int64_t embedding_dim)
: embedding_dim(embedding_dim) {
blocks["timestep_embedder"] = std::make_shared<LTXV::TimestepEmbedder>(embedding_dim);
}
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) {
auto timestep_embedder = std::dynamic_pointer_cast<LTXV::TimestepEmbedder>(blocks["timestep_embedder"]);
return timestep_embedder->forward(ctx, timestep);
}
};
struct ResnetBlock3D : public GGMLBlock {
int64_t channels;
bool timestep_conditioning;
protected:
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
if (timestep_conditioning) {
params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, channels, 4);
}
}
public:
ResnetBlock3D(int64_t channels,
float eps = 1e-6f,
bool timestep_conditioning = false)
: channels(channels), timestep_conditioning(timestep_conditioning) {
blocks["norm1"] = std::make_shared<PixelNorm3D>(eps);
blocks["conv1"] = std::make_shared<CausalConv3d>(channels, channels, 3);
blocks["norm2"] = std::make_shared<PixelNorm3D>(eps);
blocks["conv2"] = std::make_shared<CausalConv3d>(channels, channels, 3);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep = nullptr,
bool causal = false) {
auto norm1 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm1"]);
auto conv1 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv1"]);
auto norm2 = std::dynamic_pointer_cast<PixelNorm3D>(blocks["norm2"]);
auto conv2 = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv2"]);
ggml_tensor* shift1 = nullptr;
ggml_tensor* scale1 = nullptr;
ggml_tensor* shift2 = nullptr;
ggml_tensor* scale2 = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
auto values = ggml_add(ctx->ggml_ctx,
params["scale_shift_table"],
ggml_reshape_2d(ctx->ggml_ctx, timestep, channels, 4));
auto chunks = ggml_ext_chunk(ctx->ggml_ctx, values, 4, 1, false);
shift1 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[0]));
scale1 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[1]));
shift2 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[2]));
scale2 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[3]));
}
auto h = norm1->forward(ctx, x);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv1->forward(ctx, h, causal);
h = norm2->forward(ctx, h);
if (timestep_conditioning) {
h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2);
}
h = ggml_silu_inplace(ctx->ggml_ctx, h);
h = conv2->forward(ctx, h, causal);
return ggml_add(ctx->ggml_ctx, h, x);
}
};
struct UNetMidBlock3D : public GGMLBlock {
int64_t channels;
int num_layers;
bool timestep_conditioning;
UNetMidBlock3D(int64_t channels,
int num_layers,
bool timestep_conditioning)
: channels(channels),
num_layers(num_layers),
timestep_conditioning(timestep_conditioning) {
if (timestep_conditioning) {
blocks["time_embedder"] = std::make_shared<PixArtAlphaCombinedTimestepSizeEmbeddings>(channels * 4);
}
for (int i = 0; i < num_layers; i++) {
blocks["res_blocks." + std::to_string(i)] = std::make_shared<ResnetBlock3D>(channels, 1e-6f, timestep_conditioning);
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep = nullptr,
bool causal = false) {
ggml_tensor* timestep_embed = nullptr;
if (timestep_conditioning) {
GGML_ASSERT(timestep != nullptr);
auto time_embedder = std::dynamic_pointer_cast<PixArtAlphaCombinedTimestepSizeEmbeddings>(blocks["time_embedder"]);
timestep_embed = time_embedder->forward(ctx, timestep);
}
for (int i = 0; i < num_layers; i++) {
auto resnet = std::dynamic_pointer_cast<ResnetBlock3D>(blocks["res_blocks." + std::to_string(i)]);
x = resnet->forward(ctx, x, timestep_embed, causal);
}
return x;
}
};
struct DepthToSpaceUpsample : public GGMLBlock {
int64_t in_channels;
int factor_t;
int factor_s;
int out_channels_reduction_factor;
bool residual;
DepthToSpaceUpsample(int64_t in_channels,
int factor_t = 2,
int factor_s = 2,
int out_channels_reduction_factor = 2,
bool residual = true)
: in_channels(in_channels),
factor_t(factor_t),
factor_s(factor_s),
out_channels_reduction_factor(out_channels_reduction_factor),
residual(residual) {
const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s);
const int64_t out_dim = (factor * in_channels) / out_channels_reduction_factor;
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels, out_dim, 3);
}
int64_t get_output_channels() const {
return in_channels / out_channels_reduction_factor;
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal = false) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
ggml_tensor* x_in = nullptr;
if (residual) {
x_in = depth_to_space_3d(ctx->ggml_ctx, x, in_channels / (factor_t * factor_s * factor_s), factor_t, factor_s, factor_t > 1);
int repeat = (factor_t * factor_s * factor_s) / out_channels_reduction_factor;
auto res = x_in;
for (int i = 1; i < repeat; i++) {
res = ggml_concat(ctx->ggml_ctx, res, x_in, 3);
}
x_in = res;
}
x = conv->forward(ctx, x, causal);
x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, factor_t > 1);
if (residual) {
x = ggml_add(ctx->ggml_ctx, x, x_in);
}
return x;
}
};
struct SpaceToDepthDownsample : public GGMLBlock {
int64_t in_channels;
int64_t out_channels;
int factor_t;
int factor_s;
SpaceToDepthDownsample(int64_t in_channels,
int64_t out_channels,
int factor_t,
int factor_s)
: in_channels(in_channels),
out_channels(out_channels),
factor_t(factor_t),
factor_s(factor_s) {
const int64_t factor = static_cast<int64_t>(factor_t) * static_cast<int64_t>(factor_s) * static_cast<int64_t>(factor_s);
GGML_ASSERT(out_channels % factor == 0);
blocks["conv"] = std::make_shared<CausalConv3d>(in_channels, out_channels / factor, 3);
blocks["skip_downsample"] = std::make_shared<WAN::AvgDown3D>(in_channels, out_channels, factor_t, factor_s);
blocks["conv_downsample"] = std::make_shared<WAN::AvgDown3D>(out_channels / factor, out_channels, factor_t, factor_s);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
bool causal = true) {
auto conv = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv"]);
auto skip_downsample = std::dynamic_pointer_cast<WAN::AvgDown3D>(blocks["skip_downsample"]);
auto conv_downsample = std::dynamic_pointer_cast<WAN::AvgDown3D>(blocks["conv_downsample"]);
if (factor_t > 1 && x->ne[2] > 0) {
auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1);
auto first_frame_pad = first_frame;
for (int i = 1; i < factor_t; ++i) {
first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2);
}
x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2);
}
auto residual = skip_downsample->forward(ctx, x);
auto h = conv->forward(ctx, x, causal);
h = conv_downsample->forward(ctx, h);
return ggml_add(ctx->ggml_ctx, h, residual);
}
};
struct PerChannelStatistics : public GGMLBlock {
protected:
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
params["std-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
params["mean-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);
}
public:
ggml_tensor* un_normalize(GGMLRunnerContext* ctx,
ggml_tensor* x) {
auto std_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["std-of-means"]);
auto mean_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["mean-of-means"]);
return ggml_add(ctx->ggml_ctx, ggml_mul(ctx->ggml_ctx, x, std_tensor), mean_tensor);
}
ggml_tensor* normalize(GGMLRunnerContext* ctx,
ggml_tensor* x) {
auto std_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["std-of-means"]);
auto mean_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["mean-of-means"]);
return ggml_div(ctx->ggml_ctx, ggml_sub(ctx->ggml_ctx, x, mean_tensor), std_tensor);
}
};
struct DecoderConfig {
struct Block {
std::string type;
int num_layers = 0;
int multiplier = 1;
};
std::vector<Block> blocks;
};
struct EncoderConfig {
struct Block {
std::string type;
int num_layers = 0;
int multiplier = 1;
};
std::vector<Block> blocks;
};
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
const std::string& name) {
return tensor_storage_map.find(name) != tensor_storage_map.end();
}
static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map,
const std::string& name,
int64_t fallback = 0) {
auto iter = tensor_storage_map.find(name);
if (iter == tensor_storage_map.end()) {
return fallback;
}
return iter->second.ne[0];
}
static inline DecoderConfig infer_decoder_config_from_weights(const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int64_t conv_in_channels) {
DecoderConfig cfg;
const std::string decoder_prefix = prefix + ".decoder.up_blocks.";
int64_t current_channels = conv_in_channels;
for (int block_idx = 0;; ++block_idx) {
const std::string block_prefix = decoder_prefix + std::to_string(block_idx);
const std::string res0_bias = block_prefix + ".res_blocks.0.conv1.conv.bias";
const std::string conv_bias = block_prefix + ".conv.conv.bias";
if (has_tensor(tensor_storage_map, res0_bias)) {
int num_layers = 0;
while (has_tensor(tensor_storage_map,
block_prefix + ".res_blocks." + std::to_string(num_layers) + ".conv1.conv.bias")) {
num_layers++;
}
cfg.blocks.push_back({"res_x", num_layers, 1});
current_channels = get_tensor_ne0(tensor_storage_map, res0_bias, current_channels);
continue;
}
if (!has_tensor(tensor_storage_map, conv_bias)) {
break;
}
int64_t next_channels = 0;
for (int next_idx = block_idx + 1;; ++next_idx) {
const std::string next_res0_bias = decoder_prefix + std::to_string(next_idx) + ".res_blocks.0.conv1.conv.bias";
const std::string next_conv_bias = decoder_prefix + std::to_string(next_idx) + ".conv.conv.bias";
if (has_tensor(tensor_storage_map, next_res0_bias)) {
next_channels = get_tensor_ne0(tensor_storage_map, next_res0_bias);
break;
}
if (!has_tensor(tensor_storage_map, next_conv_bias)) {
break;
}
}
if (next_channels <= 0 || current_channels % next_channels != 0) {
next_channels = std::max<int64_t>(1, current_channels / 2);
}
const int64_t conv_out_dim = get_tensor_ne0(tensor_storage_map, conv_bias);
const int64_t reduction = std::max<int64_t>(1, current_channels / next_channels);
const int64_t factor = next_channels > 0 ? conv_out_dim / next_channels : 0;
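// conv_out_dim == next_channels * factor_t * factor_s * factor_s, so a factor
// of 8 upsamples time and space (2x2x2), 4 upsamples space only (1x2x2), and
// 2 upsamples time only (2x1x1).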
if (factor == 8) {
cfg.blocks.push_back({"compress_all", 0, static_cast<int>(reduction)});
} else if (factor == 4) {
cfg.blocks.push_back({"compress_space", 0, static_cast<int>(reduction)});
} else if (factor == 2) {
cfg.blocks.push_back({"compress_time", 0, static_cast<int>(reduction)});
} else {
LOG_WARN("unexpected LTX VAE upsample factor at '%s': conv_out=%lld current=%lld next=%lld, falling back to compress_all x%d",
block_prefix.c_str(),
(long long)conv_out_dim,
(long long)current_channels,
(long long)next_channels,
(int)reduction);
cfg.blocks.push_back({"compress_all", 0, static_cast<int>(reduction)});
}
current_channels = next_channels;
}
return cfg;
}
static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map,
const std::string& prefix) {
const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias";
if (tensor_storage_map.find(v2_probe) != tensor_storage_map.end()) {
return 2;
}
return 1;
}
static inline bool detect_ltx_vae_timestep_conditioning(const String2TensorStorage& tensor_storage_map,
const std::string& prefix) {
return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end();
}
static inline EncoderConfig get_encoder_config(int version) {
EncoderConfig cfg;
if (version < 2) {
GGML_ABORT("LTX VAE encoder is only implemented for version >= 2");
}
cfg.blocks = {
{"res_x", 4, 1},
{"compress_space_res", 0, 2},
{"res_x", 6, 1},
{"compress_time_res", 0, 2},
{"res_x", 6, 1},
{"compress_all_res", 0, 2},
{"res_x", 2, 1},
{"compress_all_res", 0, 2},
{"res_x", 2, 1},
};
return cfg;
}
struct Encoder : public GGMLBlock {
int version;
int patch_size;
int64_t in_channels;
int64_t latent_channels;
Encoder(int version,
int patch_size = 4,
int64_t in_channels = 3,
int64_t latent_channels = 128)
: version(version),
patch_size(patch_size),
in_channels(in_channels),
latent_channels(latent_channels) {
auto cfg = get_encoder_config(version);
int64_t channels = 128;
int64_t in_dim = in_channels * patch_size * patch_size;
blocks["conv_in"] = std::make_shared<CausalConv3d>(in_dim, channels, 3);
for (int block_idx = 0; block_idx < static_cast<int>(cfg.blocks.size()); ++block_idx) {
const auto& block = cfg.blocks[block_idx];
if (block.type == "res_x") {
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<UNetMidBlock3D>(channels,
block.num_layers,
false);
} else if (block.type == "compress_space_res") {
int64_t next_channels = channels * block.multiplier;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels,
1,
2);
channels = next_channels;
} else if (block.type == "compress_time_res") {
int64_t next_channels = channels * block.multiplier;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels,
2,
1);
channels = next_channels;
} else if (block.type == "compress_all_res") {
int64_t next_channels = channels * block.multiplier;
blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared<SpaceToDepthDownsample>(channels,
next_channels,
2,
2);
channels = next_channels;
} else {
GGML_ABORT("Unsupported LTX VAE encoder block");
}
}
blocks["conv_norm_out"] = std::make_shared<PixelNorm3D>();
blocks["conv_out"] = std::make_shared<CausalConv3d>(channels, latent_channels + 1, 3);
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
x = conv_in->forward(ctx, x, true);
int block_idx = 0;
while (blocks.find("down_blocks." + std::to_string(block_idx)) != blocks.end()) {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["down_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, nullptr, true);
} else {
auto downsample = std::dynamic_pointer_cast<SpaceToDepthDownsample>(blocks["down_blocks." + std::to_string(block_idx)]);
x = downsample->forward(ctx, x, true);
}
block_idx++;
}
x = conv_norm_out->forward(ctx, x);
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, true);
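// conv_out emits latent_channels + 1 channels; broadcasting the last channel
// into latent_channels - 1 extra copies yields 2 * latent_channels channels,
// presumably a [mean | shared log-variance] layout, of which encode() keeps
// only the mean chunk.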
auto last_channel = ggml_ext_slice(ctx->ggml_ctx, x, 3, x->ne[3] - 1, x->ne[3]);
auto repeat_shape = ggml_new_tensor_4d(ctx->ggml_ctx, last_channel->type, last_channel->ne[0], last_channel->ne[1], last_channel->ne[2], latent_channels - 1);
auto repeated = ggml_repeat(ctx->ggml_ctx, last_channel, repeat_shape);
return ggml_concat(ctx->ggml_ctx, x, repeated, 3);
}
};
struct Decoder : public GGMLBlock {
int version;
int patch_size;
bool causal_decoder;
bool timestep_conditioning;
int64_t in_channels;
int64_t hidden_channels;
protected:
void init_params(ggml_context* ctx,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "") override {
if (timestep_conditioning) {
params["timestep_scale_multiplier"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
params["last_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_channels, 2);
}
}
public:
Decoder(int version,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
int patch_size = 4,
bool causal_decoder = false,
bool timestep_conditioning = true,
int64_t in_channels = 128,
int64_t hidden_channels = 128)
: version(version),
patch_size(patch_size),
causal_decoder(causal_decoder),
timestep_conditioning(timestep_conditioning),
in_channels(in_channels),
hidden_channels(hidden_channels) {
const int64_t conv_in_out_channels = get_tensor_ne0(tensor_storage_map,
prefix + ".decoder.conv_in.conv.bias",
hidden_channels);
auto cfg = infer_decoder_config_from_weights(tensor_storage_map,
prefix,
conv_in_out_channels);
int64_t channels = conv_in_out_channels;
blocks["conv_in"] = std::make_shared<CausalConv3d>(in_channels, channels, 3);
for (int block_idx = 0; block_idx < static_cast<int>(cfg.blocks.size()); ++block_idx) {
const auto& block = cfg.blocks[block_idx];
if (block.type == "res_x") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<UNetMidBlock3D>(channels,
block.num_layers,
timestep_conditioning);
} else if (block.type == "compress_all") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<DepthToSpaceUpsample>(channels,
2,
2,
block.multiplier,
false);
channels /= block.multiplier;
} else if (block.type == "compress_time") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<DepthToSpaceUpsample>(channels,
2,
1,
block.multiplier,
false);
channels /= block.multiplier;
} else if (block.type == "compress_space") {
blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared<DepthToSpaceUpsample>(channels,
1,
2,
block.multiplier,
false);
channels /= block.multiplier;
} else {
GGML_ABORT("Unsupported LTX VAE decoder block");
}
}
hidden_channels = channels;
blocks["conv_norm_out"] = std::make_shared<PixelNorm3D>();
blocks["conv_out"] = std::make_shared<CausalConv3d>(hidden_channels, 3 * patch_size * patch_size, 3);
if (timestep_conditioning) {
blocks["last_time_embedder"] = std::make_shared<PixArtAlphaCombinedTimestepSizeEmbeddings>(hidden_channels * 2);
}
}
ggml_tensor* forward(GGMLRunnerContext* ctx,
ggml_tensor* x,
ggml_tensor* timestep) {
auto conv_in = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_in"]);
auto conv_norm_out = std::dynamic_pointer_cast<PixelNorm3D>(blocks["conv_norm_out"]);
auto conv_out = std::dynamic_pointer_cast<CausalConv3d>(blocks["conv_out"]);
ggml_tensor* scaled_timestep = timestep;
if (timestep_conditioning) {
auto multiplier = ggml_ext_backend_tensor_get_f32(params["timestep_scale_multiplier"]);
scaled_timestep = ggml_ext_scale(ctx->ggml_ctx, timestep, multiplier);
}
x = conv_in->forward(ctx, x, causal_decoder);
int block_idx = 0;
while (blocks.find("up_blocks." + std::to_string(block_idx)) != blocks.end()) {
auto mid_block = std::dynamic_pointer_cast<UNetMidBlock3D>(blocks["up_blocks." + std::to_string(block_idx)]);
if (mid_block) {
x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder);
} else {
auto upsample = std::dynamic_pointer_cast<DepthToSpaceUpsample>(blocks["up_blocks." + std::to_string(block_idx)]);
x = upsample->forward(ctx, x, causal_decoder);
}
block_idx++;
}
x = conv_norm_out->forward(ctx, x);
if (timestep_conditioning) {
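// AdaLN-style final modulation: combine the learned scale/shift table with
// the timestep embedding and apply it to the normalized output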
auto last_time_embedder = std::dynamic_pointer_cast<PixArtAlphaCombinedTimestepSizeEmbeddings>(blocks["last_time_embedder"]);
auto timestep_embed = last_time_embedder->forward(ctx, scaled_timestep);
auto [shift, scale] = get_shift_scale(ctx->ggml_ctx,
params["last_scale_shift_table"],
timestep_embed,
hidden_channels,
2);
x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift);
}
x = ggml_silu_inplace(ctx->ggml_ctx, x);
x = conv_out->forward(ctx, x, causal_decoder);
return x;
}
};
struct VideoVAE : public GGMLBlock {
int version;
float decode_timestep;
bool timestep_conditioning;
int patch_size;
bool decode_only;
VideoVAE(int version,
bool decode_only,
bool timestep_conditioning,
int patch_size,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
float decode_timestep = 0.05f)
: version(version),
decode_timestep(decode_timestep),
timestep_conditioning(timestep_conditioning),
patch_size(patch_size),
decode_only(decode_only) {
if (!decode_only) {
blocks["encoder"] = std::make_shared<Encoder>(version, patch_size);
}
blocks["decoder"] = std::make_shared<Decoder>(version,
tensor_storage_map,
prefix,
patch_size,
false,
timestep_conditioning);
blocks["per_channel_statistics"] = std::make_shared<PerChannelStatistics>();
}
ggml_tensor* decode(GGMLRunnerContext* ctx,
ggml_tensor* z,
ggml_tensor* timestep) {
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
auto latents = processor->un_normalize(ctx, z);
auto out = decoder->forward(ctx, latents, timestep);
out = WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1);
return out;
}
ggml_tensor* encode(GGMLRunnerContext* ctx,
ggml_tensor* x) {
GGML_ASSERT(!decode_only);
auto encoder = std::dynamic_pointer_cast<Encoder>(blocks["encoder"]);
auto processor = std::dynamic_pointer_cast<PerChannelStatistics>(blocks["per_channel_statistics"]);
x = patchify(ctx->ggml_ctx, x, patch_size);
auto out = encoder->forward(ctx, x);
auto mean = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3, false)[0];
mean = ggml_cont(ctx->ggml_ctx, mean);
return processor->normalize(ctx, mean);
}
};
} // namespace LTXVAE
struct LTXVideoVAE : public VAE {
bool decode_only;
int ltx_vae_version;
bool timestep_conditioning;
int patch_size;
sd::Tensor<float> decode_timestep_tensor;
LTXVAE::VideoVAE vae;
LTXVideoVAE(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map,
const std::string& prefix,
bool decode_only = true,
SDVersion version = VERSION_LTXAV)
: VAE(version, backend, offload_params_to_cpu),
decode_only(decode_only),
ltx_vae_version(LTXVAE::detect_ltx_vae_version(tensor_storage_map, prefix)),
timestep_conditioning(LTXVAE::detect_ltx_vae_timestep_conditioning(tensor_storage_map, prefix)),
patch_size(4),
decode_timestep_tensor(sd::Tensor<float>::from_vector({0.05f})),
// members are initialized in declaration order, so ltx_vae_version,
// timestep_conditioning and patch_size are valid by the time vae is built
vae(ltx_vae_version,
decode_only,
timestep_conditioning,
patch_size,
tensor_storage_map,
prefix) {
vae.init(params_ctx, tensor_storage_map, prefix);
decode_timestep_tensor.values()[0] = vae.decode_timestep;
}
std::string get_desc() override {
return "ltx_video_vae";
}
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {
vae.get_param_tensors(tensors, prefix);
}
ggml_cgraph* build_graph(const sd::Tensor<float>& z_tensor, bool decode_graph) {
LOG_DEBUG("ltx_video_vae build_graph input %dx%dx%dx%d",
(int)z_tensor.shape()[0],
(int)z_tensor.shape()[1],
(int)z_tensor.shape()[2],
(int)z_tensor.shape()[3]);
ggml_cgraph* gf = ggml_new_graph(compute_ctx);
ggml_tensor* z = make_input(z_tensor);
ggml_tensor* timestep = nullptr;
if (timestep_conditioning) {
timestep = make_input(decode_timestep_tensor);
}
auto runner_ctx = get_context();
ggml_tensor* out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z);
LOG_DEBUG("ltx_video_vae build_graph output ne=[%lld,%lld,%lld,%lld]",
(long long)out->ne[0],
(long long)out->ne[1],
(long long)out->ne[2],
(long long)out->ne[3]);
ggml_build_forward_expand(gf, out);
return gf;
}
sd::Tensor<float> _compute(const int n_threads,
const sd::Tensor<float>& z,
bool decode_graph) override {
if (!decode_graph && decode_only) {
LOG_ERROR("LTX video VAE encoder is not implemented yet");
return {};
}
sd::Tensor<float> input = z;
size_t expected_dim = static_cast<size_t>(z.dim());
if (!decode_graph) {
if (input.dim() == 4) {
input = input.unsqueeze(2);
expected_dim = 5;
} else if (input.dim() != 5) {
LOG_ERROR("LTX video VAE encoder expects 4D image or 5D video input, got dim=%lld",
(long long)input.dim());
return {};
}
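// the encoder maps 8*n+1 input frames to n+1 latent frames, so crop the time
// axis down to the nearest 8*n+1, e.g. 20 frames -> 1 + ((20 - 1) / 8) * 8 = 17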
int64_t cropped_t = std::max<int64_t>(1, 1 + ((input.shape()[2] - 1) / 8) * 8);
if (cropped_t != input.shape()[2]) {
input = sd::ops::slice(input, 2, 0, cropped_t);
}
}
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input, decode_graph);
};
auto result = restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), expected_dim);
if (result.empty()) {
return {};
}
LOG_DEBUG("ltx_video_vae host output shape=[%lld,%lld,%lld,%lld] dim=%lld",
(long long)(result.shape().size() > 0 ? result.shape()[0] : 0),
(long long)(result.shape().size() > 1 ? result.shape()[1] : 0),
(long long)(result.shape().size() > 2 ? result.shape()[2] : 0),
(long long)(result.shape().size() > 3 ? result.shape()[3] : 0),
(long long)result.dim());
return result;
}
int get_encoder_output_channels(int input_channels) override {
SD_UNUSED(input_channels);
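// 2 * 128: mean plus logvar layout; vae_output_to_latents below keeps only
// the first 128 channels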
return 256;
}
sd::Tensor<float> vae_output_to_latents(const sd::Tensor<float>& vae_output, std::shared_ptr<RNG> rng) override {
SD_UNUSED(rng);
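// the first 128 channels hold the latent mean; drop any padded logvar copies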
if (vae_output.dim() >= 4 && vae_output.shape()[3] > 128) {
return sd::ops::slice(vae_output, 3, 0, 128);
}
return vae_output;
}
sd::Tensor<float> diffusion_to_vae_latents(const sd::Tensor<float>& latents) override {
return latents;
}
sd::Tensor<float> vae_to_diffusion_latents(const sd::Tensor<float>& latents) override {
return latents;
}
void test(const std::string& input_path) {
auto z = sd::load_tensor_from_file_as_tensor<float>(input_path);
print_sd_tensor(z, false, "ltx_vae_z");
z = diffusion_to_vae_latents(z);
int64_t t0 = ggml_time_ms();
auto out = _compute(8, z, true);
int64_t t1 = ggml_time_ms();
GGML_ASSERT(!out.empty());
print_sd_tensor(out, false, "ltx_vae_out");
LOG_DEBUG("ltx vae test done in %lldms", t1 - t0);
}
static void load_from_file_and_test(const std::string& model_path,
const std::string& input_path) {
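// debug/test helper; the backend is hard-coded to CUDA device 0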
ggml_backend_t backend = ggml_backend_cuda_init(0);
LOG_INFO("loading ltx vae from '%s'", model_path.c_str());
ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(model_path, "vae.")) {
LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str());
return;
}
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
std::shared_ptr<LTXVideoVAE> vae = std::make_shared<LTXVideoVAE>(backend,
false,
tensor_storage_map,
"first_stage_model",
true,
VERSION_LTXAV);
vae->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;
vae->get_param_tensors(tensors, "first_stage_model");
if (!model_loader.load_tensors(tensors)) {
LOG_ERROR("load tensors from model loader failed");
return;
}
LOG_INFO("ltx vae model loaded");
vae->test(input_path);
}
};
#endif // __SD_LTX_VAE_H__

src/ltx_vae_test.cpp Normal file
View File

@ -0,0 +1,8 @@
#include "ltx_vae_test.h"
#include "ltx_vae.h"
void ltx_vae_load_from_file_and_test(const std::string& model_path,
const std::string& input_path) {
LTXVideoVAE::load_from_file_and_test(model_path, input_path);
}

src/ltx_vae_test.h Normal file
View File

@ -0,0 +1,9 @@
#ifndef __SD_LTX_VAE_TEST_H__
#define __SD_LTX_VAE_TEST_H__
#include <string>
void ltx_vae_load_from_file_and_test(const std::string& model_path,
const std::string& input_path);
#endif // __SD_LTX_VAE_TEST_H__

File diff suppressed because it is too large

View File

@ -471,6 +471,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) {
return VERSION_ERNIE_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.adaln_single.emb.timestep_embedder.linear_1.bias") != std::string::npos) {
return VERSION_LTXAV;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
is_wan = true;
}

View File

@ -42,6 +42,7 @@ enum SDVersion {
VERSION_ANIMA,
VERSION_FLUX2,
VERSION_FLUX2_KLEIN,
VERSION_LTXAV,
VERSION_Z_IMAGE,
VERSION_OVIS_IMAGE,
VERSION_ERNIE_IMAGE,
@ -104,6 +105,13 @@ static inline bool sd_version_is_flux2(SDVersion version) {
return false;
}
static inline bool sd_version_is_ltxav(SDVersion version) {
if (version == VERSION_LTXAV) {
return true;
}
return false;
}
static inline bool sd_version_is_wan(SDVersion version) {
if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) {
return true;
@ -160,6 +168,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) ||
sd_version_is_flux2(version) ||
sd_version_is_ltxav(version) ||
sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||

View File

@ -14,6 +14,7 @@
#include "diffusion_model.hpp"
#include "esrgan.hpp"
#include "lora.hpp"
#include "ltx_vae.h"
#include "pmid.hpp"
#include "sample-cache.h"
#include "tae.hpp"
@ -52,6 +53,7 @@ const char* model_version_to_str[] = {
"Anima",
"Flux.2",
"Flux.2 klein",
"LTXAV",
"Z-Image",
"Ovis Image",
"Ernie Image",
@ -351,6 +353,17 @@ public:
return false;
}
if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) {
if (sd_version_is_ltxav(version)) {
LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path);
if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) {
LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path);
}
} else {
LOG_WARN("ignoring embeddings connectors for non-LTXAV model: '%s'", sd_ctx_params->embeddings_connectors_path);
}
}
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
LOG_INFO("Version: %s ", model_version_to_str[version]);
@ -415,6 +428,9 @@ public:
// Might need vae encode for control cond
vae_decode_only = false;
}
if (sd_version_is_ltxav(version)) {
vae_decode_only = true;
}
bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
@ -492,6 +508,14 @@ public:
tensor_storage_map,
version,
sd_ctx_params->chroma_use_dit_mask);
} else if (sd_version_is_ltxav(version)) {
cond_stage_model = std::make_shared<LTXAVEmbedder>(clip_backend,
offload_params_to_cpu,
tensor_storage_map);
diffusion_model = std::make_shared<LTXAVModel>(backend,
offload_params_to_cpu,
tensor_storage_map,
"model.diffusion_model");
} else if (sd_version_is_wan(version)) {
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
offload_params_to_cpu,
@ -638,7 +662,14 @@ public:
};
auto create_vae = [&]() -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) ||
if (sd_version_is_ltxav(version)) {
return std::make_shared<LTXVideoVAE>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
true,
version);
} else if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version)) {
return std::make_shared<WAN::WanVAERunner>(vae_backend,
@ -936,13 +967,16 @@ public:
pred_type = EPS_PRED;
}
} else if (sd_version_is_sd3(version) ||
sd_version_is_ltxav(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_ernie_image(version) ||
sd_version_is_z_image(version)) {
pred_type = FLOW_PRED;
if (sd_version_is_wan(version)) {
if (sd_version_is_ltxav(version)) {
default_flow_shift = 2.37f;
} else if (sd_version_is_wan(version)) {
default_flow_shift = 5.f;
} else if (sd_version_is_ernie_image(version)) {
default_flow_shift = 4.f;
@ -979,8 +1013,13 @@ public:
denoiser = std::make_shared<EDMVDenoiser>();
break;
case FLOW_PRED: {
if (sd_version_is_ltxav(version)) {
LOG_INFO("running in LTXAV FLOW mode");
denoiser = std::make_shared<FluxFlowDenoiser>();
} else {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
}
break;
}
case FLUX_FLOW_PRED: {
@ -1621,6 +1660,7 @@ public:
const sd::Tensor<float>& denoise_mask,
const sd::Tensor<float>& vace_context,
float vace_strength,
int audio_length,
const sd_cache_params_t* cache_params) {
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
float cfg_scale = guidance.txt_cfg;
@ -1699,6 +1739,7 @@ public:
diffusion_params.control_strength = control_strength;
diffusion_params.vace_context = vace_context.empty() ? nullptr : &vace_context;
diffusion_params.vace_strength = vace_strength;
diffusion_params.audio_length = audio_length;
diffusion_params.skip_layers = nullptr;
compute_sample_controls(control_image,
@ -1860,7 +1901,9 @@ public:
int get_latent_channel() {
int latent_channel = 4;
if (sd_version_is_dit(version)) {
if (version == VERSION_WAN2_2_TI2V) {
if (sd_version_is_ltxav(version)) {
latent_channel = 128;
} else if (version == VERSION_WAN2_2_TI2V) {
latent_channel = 48;
} else if (version == VERSION_CHROMA_RADIANCE) {
latent_channel = 3;
@ -1886,7 +1929,9 @@ public:
int W = width / vae_scale_factor;
int H = height / vae_scale_factor;
int T = frames;
if (sd_version_is_wan(version)) {
if (sd_version_is_ltxav(version)) {
T = ((T - 1) / 8) + 1;
} else if (sd_version_is_wan(version)) {
T = ((T - 1) / 4) + 1;
}
int C = get_latent_channel();
@ -2223,6 +2268,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"llm_vision_path: %s\n"
"diffusion_model_path: %s\n"
"high_noise_diffusion_model_path: %s\n"
"embeddings_connectors_path: %s\n"
"vae_path: %s\n"
"taesd_path: %s\n"
"control_net_path: %s\n"
@ -2255,6 +2301,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->llm_vision_path),
SAFE_STR(sd_ctx_params->diffusion_model_path),
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->embeddings_connectors_path),
SAFE_STR(sd_ctx_params->vae_path),
SAFE_STR(sd_ctx_params->taesd_path),
SAFE_STR(sd_ctx_params->control_net_path),
@ -2433,6 +2480,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->strength = 0.75f;
sd_vid_gen_params->seed = -1;
sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->fps = 16;
sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f;
sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
@ -2444,7 +2492,7 @@ struct sd_ctx_t {
};
static bool sd_version_supports_video_generation(SDVersion version) {
return version == VERSION_SVD || sd_version_is_wan(version);
return version == VERSION_SVD || sd_version_is_wan(version) || sd_version_is_ltxav(version);
}
static bool sd_version_supports_image_generation(SDVersion version) {
@ -2589,6 +2637,8 @@ struct GenerationRequest {
sd_pm_params_t pm_params = {};
sd_hires_params_t hires = {};
int frames = -1;
int requested_frames = -1;
int fps = 16;
float vace_strength = 1.f;
GenerationRequest(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
@ -2619,8 +2669,14 @@ struct GenerationRequest {
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
width = sd_vid_gen_params->width;
height = sd_vid_gen_params->height;
frames = (sd_vid_gen_params->video_frames - 1) / 4 * 4 + 1;
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
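// LTX latents cover 8 frames per temporal step (vs 4 for WAN), so round the
// request up to the next 8*n+1, e.g. 100 -> ((99 + 7) / 8) * 8 + 1 = 105;
// the surplus decoded frames are trimmed again in decode_video_outputs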
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
frames = ((requested_frames - 1 + 7) / 8) * 8 + 1;
} else {
frames = (requested_frames - 1) / 4 * 4 + 1;
}
clip_skip = sd_vid_gen_params->clip_skip;
fps = std::max(1, sd_vid_gen_params->fps);
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
seed = sd_vid_gen_params->seed;
@ -2629,6 +2685,12 @@ struct GenerationRequest {
guidance = sd_vid_gen_params->sample_params.guidance;
high_noise_guidance = sd_vid_gen_params->high_noise_sample_params.guidance;
resolve(sd_ctx);
if (frames != requested_frames) {
LOG_WARN("align video frames from %d to %d for %s",
requested_frames,
frames,
model_version_to_str[sd_ctx->sd->version]);
}
}
void align_generation_request_size() {
@ -2858,6 +2920,7 @@ struct ImageGenerationLatents {
sd::Tensor<float> init_latent;
sd::Tensor<float> concat_latent;
sd::Tensor<float> uncond_concat_latent;
sd::Tensor<float> audio_latent;
sd::Tensor<float> control_image;
std::vector<sd::Tensor<float>> ref_images;
std::vector<sd::Tensor<float>> ref_latents;
@ -2865,8 +2928,51 @@ struct ImageGenerationLatents {
sd::Tensor<float> clip_vision_output;
sd::Tensor<float> vace_context;
int64_t ref_image_num = 0;
int audio_length = 0;
};
static sd::Tensor<float> pack_ltxav_audio_and_video_latents(const sd::Tensor<float>& video_latent,
const sd::Tensor<float>& audio_latent) {
if (audio_latent.empty()) {
return video_latent;
}
GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5);
GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
if (video_latent.dim() == 5) {
GGML_ASSERT(video_latent.shape()[4] == 1);
}
if (audio_latent.dim() == 4) {
GGML_ASSERT(audio_latent.shape()[3] == 1);
}
int64_t width = video_latent.shape()[0];
int64_t height = video_latent.shape()[1];
int64_t frames = video_latent.shape()[2];
int64_t video_ch = video_latent.shape()[3];
int64_t spatial_size = width * height * frames;
int64_t audio_values = audio_latent.numel();
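// spread the flattened audio latent across extra channels of the video grid,
// rounding up: e.g. a 16x190x8 audio latent (24320 values) over a 32x32x13
// latent grid (13312 positions) needs ceil(24320 / 13312) = 2 extra channels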
int64_t extra_ch = (audio_values + spatial_size - 1) / spatial_size;
std::vector<int64_t> packed_shape = video_latent.shape();
packed_shape[3] = video_ch + extra_ch;
sd::Tensor<float> packed = sd::zeros<float>(packed_shape);
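// with dim 0 fastest in memory, channels (dim 3) are outermost, so copying
// the flattened audio right after the video data fills the appended channels;
// the tail of the last extra channel stays zero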
std::copy_n(video_latent.data(), video_latent.numel(), packed.data());
std::copy_n(audio_latent.data(), audio_latent.numel(), packed.data() + video_latent.numel());
return packed;
}
static int get_ltxav_num_audio_latents(int frames, int fps) {
GGML_ASSERT(frames > 0);
GGML_ASSERT(fps > 0);
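// 16 kHz audio, 160-sample mel hop, 4x temporal downsampling in the audio
// latent => 16000 / 160 / 4 = 25 audio latents per second of video;
// e.g. 121 frames at 16 fps -> ceil(7.5625 * 25) = 190 latents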
constexpr float kSampleRate = 16000.0f;
constexpr float kMelHopLength = 160.0f;
constexpr float kAudioLatentDownsample = 4.0f;
constexpr float kLatentsPerSecond = kSampleRate / kMelHopLength / kAudioLatentDownsample;
return static_cast<int>(std::ceil((static_cast<float>(frames) / static_cast<float>(fps)) * kLatentsPerSecond));
}
struct ImageGenerationEmbeds {
SDCondition cond;
SDCondition uncond;
@ -3454,6 +3560,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
latents.denoise_mask,
sd::Tensor<float>(),
1.f,
0,
request.cache_params);
int64_t sampling_end = ggml_time_ms();
if (!x_0.empty()) {
@ -3575,6 +3682,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
hires_denoise_mask,
sd::Tensor<float>(),
1.f,
0,
request.cache_params);
int64_t hires_sample_end = ggml_time_ms();
if (!x_0.empty()) {
@ -3633,6 +3741,18 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
end_image = sd_image_to_tensor(sd_vid_gen_params->end_image, request->width, request->height);
}
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
if (!start_image.empty() || !end_image.empty() || sd_vid_gen_params->control_frames_size > 0) {
LOG_ERROR("LTXAV currently supports txt2vid only; init_image, end_image, and control_frames are not implemented");
return std::nullopt;
}
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
latents.audio_latent = sd::zeros<float>({16, latents.audio_length, 8, 1});
}
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" ||
@ -3803,6 +3923,10 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, request->height, request->frames, true);
}
if (!latents.audio_latent.empty()) {
latents.init_latent = pack_ltxav_audio_and_video_latents(latents.init_latent, latents.audio_latent);
}
return latents;
}
@ -3839,14 +3963,26 @@ static ImageGenerationEmbeds prepare_video_generation_embeds(sd_ctx_t* sd_ctx,
}
static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
const GenerationRequest& request,
const sd::Tensor<float>& final_latent,
int* num_frames_out) {
if (final_latent.empty()) {
LOG_ERROR("no latent video to decode");
return nullptr;
}
sd::Tensor<float> video_latent = final_latent;
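// sampling may have run on video latents with packed audio channels appended;
// keep only the first get_latent_channel() (128) channels for the video VAE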
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) {
video_latent = sd::ops::slice(video_latent, 3, 0, sd_ctx->sd->get_latent_channel());
}
LOG_DEBUG("decode_video_outputs latent %dx%dx%dx%d",
(int)video_latent.shape()[0],
(int)video_latent.shape()[1],
(int)video_latent.shape()[2],
(int)video_latent.shape()[3]);
int64_t t4 = ggml_time_ms();
sd::Tensor<float> vid = sd_ctx->sd->decode_first_stage(final_latent, true);
sd::Tensor<float> vid = sd_ctx->sd->decode_first_stage(video_latent, true);
int64_t t5 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
@ -3856,6 +3992,15 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
LOG_ERROR("decode_first_stage failed for video");
return nullptr;
}
LOG_DEBUG("decode_video_outputs decoded %dx%dx%dx%d",
(int)vid.shape()[0],
(int)vid.shape()[1],
(int)vid.shape()[2],
(int)vid.shape()[3]);
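// the frame count was rounded up to 8*n+1 before sampling; trim the decoded
// video back to the number of frames the caller actually requested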
if (request.requested_frames > 0 &&
vid.shape()[2] > request.requested_frames) {
vid = sd::ops::slice(vid, 2, 0, request.requested_frames);
}
sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
if (result_images == nullptr) {
@ -3939,6 +4084,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
latents.denoise_mask,
latents.vace_context,
request.vace_strength,
latents.audio_length,
request.cache_params);
int64_t sampling_end = ggml_time_ms();
if (x_t_sampled.empty()) {
@ -3981,6 +4127,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
latents.denoise_mask,
latents.vace_context,
request.vace_strength,
latents.audio_length,
request.cache_params);
int64_t sampling_end = ggml_time_ms();
@ -4000,7 +4147,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t latent_end = ggml_time_ms();
LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000);
auto result = decode_video_outputs(sd_ctx, final_latent, num_frames_out);
auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out);
if (result == nullptr) {
return nullptr;
}

View File

@ -2,7 +2,6 @@
#define __TAE_HPP__
#include "ggml_extend.hpp"
#include "model.h"
/*

View File

@ -104,7 +104,7 @@ namespace sd {
throw std::invalid_argument("tensor file type does not match requested sd::Tensor type");
}
std::vector<int64_t> shape(4, 1);
std::vector<int64_t> shape(n_dims, 1);
for (int i = 0; i < n_dims; ++i) {
int32_t dim = 1;
file.read(reinterpret_cast<char*>(&dim), sizeof(dim));

View File

@ -50,6 +50,7 @@ GemmaTokenizer::GemmaTokenizer(const std::string& merges_utf8_str, const std::st
byte_level_bpe = false;
byte_fallback = true;
add_bos_token = true;
pad_left = true;
PAD_TOKEN = "<pad>";
EOS_TOKEN = "<eos>";
BOS_TOKEN = "<bos>";

View File

@ -67,7 +67,9 @@ public:
int get_scale_factor() {
int scale_factor = 8;
if (version == VERSION_WAN2_2_TI2V) {
if (version == VERSION_LTXAV) {
scale_factor = 32;
} else if (version == VERSION_WAN2_2_TI2V) {
scale_factor = 16;
} else if (sd_version_uses_flux2_vae(version)) {
scale_factor = 16;

View File

@ -966,7 +966,7 @@ namespace WAN {
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new CausalConv3d(z_dim, z_dim, {1, 1, 1}));
}
ggml_tensor* patchify(ggml_context* ctx,
static ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {
@ -993,7 +993,7 @@ namespace WAN {
return x;
}
ggml_tensor* unpatchify(ggml_context* ctx,
static ggml_tensor* unpatchify(ggml_context* ctx,
ggml_tensor* x,
int64_t patch_size,
int64_t b = 1) {