From 831b321c6a7f98103a70fa8d50f22cfd0283951e Mon Sep 17 00:00:00 2001 From: leejet Date: Wed, 29 Apr 2026 01:01:58 +0800 Subject: [PATCH] add basic ltx2.3 support --- examples/cli/main.cpp | 10 - examples/common/common.cpp | 7 + examples/common/common.h | 1 + include/stable-diffusion.h | 2 + src/common_dit.hpp | 58 + src/conditioner.hpp | 288 ++++- src/diffusion_model.hpp | 69 + src/ggml_extend.hpp | 61 +- src/llm.hpp | 356 ++++- src/ltx_vae.hpp | 971 ++++++++++++++ src/ltxv.hpp | 1925 +++++++++++++++++++++++++++- src/model.cpp | 3 + src/model.h | 9 + src/stable-diffusion.cpp | 178 ++- src/tae.hpp | 1 - src/tensor_ggml.hpp | 2 +- src/tokenizers/gemma_tokenizer.cpp | 1 + src/vae.hpp | 4 +- src/wan.hpp | 16 +- 19 files changed, 3807 insertions(+), 155 deletions(-) create mode 100644 src/ltx_vae.hpp diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 15b04d8f..8cec2dbc 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -19,7 +19,6 @@ #include "common/media_io.h" #include "common/resource_owners.hpp" #include "image_metadata.h" -#include "llm.hpp" namespace fs = std::filesystem; @@ -501,15 +500,6 @@ int main(int argc, const char* argv[]) { SDContextParams ctx_params; SDGenerationParams gen_params; - cli_params.verbose = true; - sd_set_log_callback(sd_log_cb, (void*)&cli_params); - GemmaTokenizer tokenizer; - auto tokens = tokenizer.tokenize(" 一只可爱的小猫"); - for (auto token : tokens) { - LOG_INFO("%d", token); - } - return 0; - parse_args(argc, argv, cli_params, ctx_params, gen_params); sd_set_log_callback(sd_log_cb, (void*)&cli_params); log_verbose = cli_params.verbose; diff --git a/examples/common/common.cpp b/examples/common/common.cpp index 2d29df26..80a37246 100644 --- a/examples/common/common.cpp +++ b/examples/common/common.cpp @@ -340,6 +340,10 @@ ArgOptions SDContextParams::get_options() { "--high-noise-diffusion-model", "path to the standalone high noise diffusion model", &high_noise_diffusion_model_path}, + {"", + 
"--embeddings-connectors", + "path to LTXAV embeddings connectors", + &embeddings_connectors_path}, {"", "--vae", "path to standalone vae model", @@ -656,6 +660,7 @@ std::string SDContextParams::to_string() const { << " llm_vision_path: \"" << llm_vision_path << "\",\n" << " diffusion_model_path: \"" << diffusion_model_path << "\",\n" << " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n" + << " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n" << " vae_path: \"" << vae_path << "\",\n" << " taesd_path: \"" << taesd_path << "\",\n" << " esrgan_path: \"" << esrgan_path << "\",\n" @@ -712,6 +717,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f llm_vision_path.c_str(), diffusion_model_path.c_str(), high_noise_diffusion_model_path.c_str(), + embeddings_connectors_path.c_str(), vae_path.c_str(), taesd_path.c_str(), control_net_path.c_str(), @@ -2180,6 +2186,7 @@ sd_vid_gen_params_t SDGenerationParams::to_sd_vid_gen_params_t() { params.strength = strength; params.seed = seed; params.video_frames = video_frames; + params.fps = fps; params.vace_strength = vace_strength; params.vae_tiling_params = vae_tiling_params; params.cache = cache_params; diff --git a/examples/common/common.h b/examples/common/common.h index 333d3311..bdd8246d 100644 --- a/examples/common/common.h +++ b/examples/common/common.h @@ -92,6 +92,7 @@ struct SDContextParams { std::string llm_vision_path; std::string diffusion_model_path; std::string high_noise_diffusion_model_path; + std::string embeddings_connectors_path; std::string vae_path; std::string taesd_path; std::string esrgan_path; diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 75027f8f..14c26a3f 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -171,6 +171,7 @@ typedef struct { const char* llm_vision_path; const char* diffusion_model_path; const char* high_noise_diffusion_model_path; + const char* 
embeddings_connectors_path; const char* vae_path; const char* taesd_path; const char* control_net_path; @@ -359,6 +360,7 @@ typedef struct { float strength; int64_t seed; int video_frames; + int fps; float vace_strength; sd_tiling_params_t vae_tiling_params; sd_cache_params_t cache; diff --git a/src/common_dit.hpp b/src/common_dit.hpp index 30141d42..c25302ff 100644 --- a/src/common_dit.hpp +++ b/src/common_dit.hpp @@ -103,6 +103,64 @@ namespace DiT { x = ggml_ext_slice(ctx, x, 0, 0, W); // [N, C, H, W] return x; } + + inline ggml_tensor* patchify(ggml_context* ctx, + ggml_tensor* x, + int pt, + int ph, + int pw, + int64_t N = 1) { + // x: [N*C, T, H, W] + // return: [N, h*w, C*pt*ph*pw] + int64_t C = x->ne[3] / N; + int64_t T = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t t_len = T / pt; + int64_t h_len = H / ph; + int64_t w_len = W / pw; + + GGML_ASSERT(C * N == x->ne[3]); + GGML_ASSERT(t_len * pt == T && h_len * ph == H && w_len * pw == W); + + x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt, t_len * C * N); // [N*C*t_len, pt, h_len*ph, w_len*pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, h_len*ph, pt, w_len*pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph, h_len * t_len * C * N); // [N*C*t_len*h_len, ph, pt, w_len*pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt, ph, w_len*pw] + x = ggml_reshape_4d(ctx, x, pw, w_len, ph * pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt*ph, w_len, pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, w_len, pt*ph, pw] + x = ggml_reshape_4d(ctx, x, pw * ph * pt, w_len * h_len * t_len, C, N); // [N, C, t_len*h_len*w_len, pt*ph*pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N, t_len*h_len*w_len, C, pt*ph*pw] + x = ggml_reshape_4d(ctx, x, pw * ph * pt * C, w_len * h_len * t_len, N, 1); // [N, t_len*h_len*w_len, C*pt*ph*pw] + return x; 
+ } + + inline ggml_tensor* unpatchify(ggml_context* ctx, + ggml_tensor* x, + int64_t t_len, + int64_t h_len, + int64_t w_len, + int pt, + int ph, + int pw) { + // x: [N, t_len*h_len*w_len, pt*ph*pw*C] + // return: [N*C, t_len*pt, h_len*ph, w_len*pw] + int64_t N = x->ne[3]; + int64_t C = x->ne[0] / pt / ph / pw; + + GGML_ASSERT(C * pt * ph * pw == x->ne[0]); + + x = ggml_reshape_4d(ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, t_len*h_len*w_len, pt*ph*pw] + x = ggml_reshape_4d(ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, pt*ph, w_len, pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len*h_len, ph, pt, w_len*pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [N*C*t_len, pt, h_len*ph, w_len*pw] + x = ggml_reshape_4d(ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw] + return x; + } } // namespace DiT #endif // __COMMON_DIT_HPP__ diff --git a/src/conditioner.hpp b/src/conditioner.hpp index 9f4d4552..1b4c32e2 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -1,6 +1,8 @@ #ifndef __CONDITIONER_HPP__ #define __CONDITIONER_HPP__ +#include +#include #include #include "clip.hpp" @@ -46,6 +48,17 @@ static inline sd::Tensor apply_token_weights(sd::Tensor hidden_sta return hidden_states; } + bool all_one = true; + for (float weight : weights) { + if (weight != 1.0f) { + all_one = false; + break; + } + } + if (all_one) { + return hidden_states; + } + if (hidden_states.dim() == 1) { 
hidden_states.unsqueeze_(1); } @@ -57,7 +70,7 @@ static inline sd::Tensor apply_token_weights(sd::Tensor hidden_sta chunk_weights.reshape_({1, static_cast(weights.size())}); hidden_states *= chunk_weights; float new_mean = hidden_states.mean(); - if (new_mean != 0.0f) { + if (std::isfinite(original_mean) && std::isfinite(new_mean) && new_mean != 0.0f) { hidden_states *= (original_mean / new_mean); } @@ -1958,4 +1971,277 @@ struct LLMEmbedder : public Conditioner { } }; +struct LTXAVTextProjection : public GGMLBlock { + static constexpr int64_t kHiddenSize = 3840; + static constexpr int64_t kNumStates = 49; + bool dual_projection = false; + + LTXAVTextProjection(bool dual_projection = false) + : dual_projection(dual_projection) { + if (dual_projection) { + blocks["video_aggregate_embed"] = std::make_shared(kHiddenSize * kNumStates, 4096, true); + blocks["audio_aggregate_embed"] = std::make_shared(kHiddenSize * kNumStates, 2048, true); + } else { + blocks["projection"] = std::make_shared(kHiddenSize * kNumStates, kHiddenSize, false); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + if (!dual_projection) { + auto projection = std::dynamic_pointer_cast(blocks["projection"]); + return projection->forward(ctx, x); + } + + auto video_projection = std::dynamic_pointer_cast(blocks["video_aggregate_embed"]); + auto audio_projection = std::dynamic_pointer_cast(blocks["audio_aggregate_embed"]); + auto video_in = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(4096.f / static_cast(kHiddenSize))); + auto audio_in = ggml_ext_scale(ctx->ggml_ctx, x, std::sqrt(2048.f / static_cast(kHiddenSize))); + auto video = video_projection->forward(ctx, video_in); + auto audio = audio_projection->forward(ctx, audio_in); + return ggml_concat(ctx->ggml_ctx, video, audio, 0); + } +}; + +struct LTXAVTextProjectionRunner : public GGMLRunner { + LTXAVTextProjection model; + + LTXAVTextProjectionRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const 
String2TensorStorage& tensor_storage_map = {}, + const std::string& prefix = "") + : GGMLRunner(backend, offload_params_to_cpu), + model(tensor_storage_map.find(prefix + ".video_aggregate_embed.weight") != tensor_storage_map.end()) { + model.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "ltxav_text_projection"; + } + + void get_param_tensors(std::map& tensors, const std::string& prefix) { + model.get_param_tensors(tensors, prefix); + } + + ggml_cgraph* build_graph(const sd::Tensor& x_tensor) { + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + auto x = make_input(x_tensor); + auto runner_ctx = get_context(); + auto out = model.forward(&runner_ctx, x); + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, const sd::Tensor& x) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x); + }; + return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); + } +}; + +struct LTXAVEmbedder : public Conditioner { + static constexpr int64_t kHiddenSize = 3840; + static constexpr int64_t kNumStates = 49; + static constexpr int64_t kMinLength = 1024; + + std::shared_ptr tokenizer; + std::shared_ptr llm; + std::shared_ptr projector; + bool dual_projection = false; + + LTXAVEmbedder(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string& llm_prefix = "text_encoders.llm", + const std::string& projector_prefix = "text_embedding_projection") { + tokenizer = std::make_shared(); + llm = std::make_shared(LLM::LLMArch::GEMMA3_12B, + backend, + offload_params_to_cpu, + tensor_storage_map, + llm_prefix, + false); + dual_projection = tensor_storage_map.find(projector_prefix + ".video_aggregate_embed.weight") != tensor_storage_map.end(); + projector = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + projector_prefix); + } + + void get_param_tensors(std::map& tensors) override { + 
llm->get_param_tensors(tensors, "text_encoders.llm"); + projector->get_param_tensors(tensors, "text_embedding_projection"); + } + + void alloc_params_buffer() override { + llm->alloc_params_buffer(); + projector->alloc_params_buffer(); + } + + void free_params_buffer() override { + llm->free_params_buffer(); + projector->free_params_buffer(); + } + + size_t get_params_buffer_size() override { + return llm->get_params_buffer_size() + projector->get_params_buffer_size(); + } + + void set_flash_attention_enabled(bool enabled) override { + llm->set_flash_attention_enabled(enabled); + projector->set_flash_attention_enabled(enabled); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + llm->set_weight_adapter(adapter); + projector->set_weight_adapter(adapter); + } + + std::tuple, std::vector, std::vector> tokenize(std::string text, + const std::pair& attn_range) { + std::vector> parsed_attention; + if (attn_range.first >= 0 && attn_range.second > 0) { + if (attn_range.first > 0) { + parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + } + if (attn_range.second - attn_range.first > 0) { + auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); + parsed_attention.insert(parsed_attention.end(), new_parsed_attention.begin(), new_parsed_attention.end()); + } + if (static_cast(attn_range.second) < text.size()) { + parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + } + } else { + parsed_attention.emplace_back(text, 1.f); + } + + std::vector tokens; + std::vector weights; + for (const auto& item : parsed_attention) { + auto curr_tokens = tokenizer->encode(item.first, nullptr); + tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); + weights.insert(weights.end(), curr_tokens.size(), item.second); + } + + std::vector mask; + tokenizer->pad_tokens(tokens, &weights, &mask, kMinLength); + return {tokens, weights, mask}; + } + + sd::Tensor 
encode_prompt(int n_threads, + const std::string& prompt, + const std::pair& prompt_attn_range) { + auto tokens_weights_mask = tokenize(prompt, prompt_attn_range); + auto& tokens = std::get<0>(tokens_weights_mask); + auto& weights = std::get<1>(tokens_weights_mask); + auto& mask = std::get<2>(tokens_weights_mask); + + sd::Tensor input_ids({static_cast(tokens.size())}, std::vector(tokens.begin(), tokens.end())); + sd::Tensor attention_mask; + if (!mask.empty()) { + const float mask_min = std::numeric_limits::lowest() / 4.0f; + attention_mask = sd::Tensor({static_cast(mask.size()), static_cast(mask.size())}); + for (size_t i1 = 0; i1 < mask.size(); ++i1) { + for (size_t i0 = 0; i0 < mask.size(); ++i0) { + float value = 0.0f; + if (mask[i0] == 0.0f) { + value += mask_min; + } + if (i0 > i1) { + value += mask_min; + } + attention_mask[static_cast(i0 + mask.size() * i1)] = value; + } + } + } + + auto hidden_states = llm->compute(n_threads, + input_ids, + attention_mask, + {}, + {}, + true); + GGML_ASSERT(!hidden_states.empty()); + hidden_states = apply_token_weights(std::move(hidden_states), weights); + + int64_t valid_tokens = 0; + for (float value : mask) { + valid_tokens += static_cast(value > 0.0f); + } + GGML_ASSERT(valid_tokens > 0); + + hidden_states = sd::ops::slice(hidden_states, + 1, + hidden_states.shape()[1] - valid_tokens, + hidden_states.shape()[1]); + hidden_states.reshape_({kHiddenSize, kNumStates, valid_tokens}); + hidden_states = hidden_states.permute({1, 0, 2}); + + if (dual_projection) { + for (int64_t state_idx = 0; state_idx < kNumStates; ++state_idx) { + for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) { + double sq_sum = 0.0; + for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) { + float value = hidden_states.index(state_idx, hidden_idx, token_idx); + sq_sum += static_cast(value) * static_cast(value); + } + + float inv_rms = 1.0f / std::sqrt(static_cast(sq_sum / static_cast(kHiddenSize)) + 1e-6f); + for 
(int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) { + hidden_states.index(state_idx, hidden_idx, token_idx) *= inv_rms; + } + } + } + } else { + for (int64_t state_idx = 0; state_idx < kNumStates; ++state_idx) { + double sum = 0.0; + float min_value = std::numeric_limits::infinity(); + float max_value = -std::numeric_limits::infinity(); + for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) { + for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) { + float value = hidden_states.index(state_idx, hidden_idx, token_idx); + sum += value; + min_value = std::min(min_value, value); + max_value = std::max(max_value, value); + } + } + + float mean_value = static_cast(sum / static_cast(kHiddenSize * valid_tokens)); + float denom = max_value - min_value + 1e-6f; + float scale_value = 8.0f / denom; + for (int64_t token_idx = 0; token_idx < valid_tokens; ++token_idx) { + for (int64_t hidden_idx = 0; hidden_idx < kHiddenSize; ++hidden_idx) { + float value = hidden_states.index(state_idx, hidden_idx, token_idx); + hidden_states.index(state_idx, hidden_idx, token_idx) = (value - mean_value) * scale_value; + } + } + } + } + + hidden_states.reshape_({kNumStates * kHiddenSize, valid_tokens}); + return projector->compute(n_threads, hidden_states); + } + + SDCondition get_learned_condition(int n_threads, + const ConditionerParams& conditioner_params) override { + int64_t t0 = ggml_time_ms(); + + std::string prompt; + std::pair prompt_attn_range; + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + auto hidden_states = encode_prompt(n_threads, prompt, prompt_attn_range); + GGML_ASSERT(!hidden_states.empty()); + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing LTXAV condition graph completed, taking %" PRId64 " ms", t1 - t0); + + SDCondition result; + result.c_crossattn = std::move(hidden_states); + return result; + } +}; + #endif diff 
--git a/src/diffusion_model.hpp b/src/diffusion_model.hpp index c0a2a11c..3fc4c3a3 100644 --- a/src/diffusion_model.hpp +++ b/src/diffusion_model.hpp @@ -5,6 +5,7 @@ #include "anima.hpp" #include "ernie_image.hpp" #include "flux.hpp" +#include "ltxv.hpp" #include "mmdit.hpp" #include "qwen_image.hpp" #include "tensor_ggml.hpp" @@ -14,7 +15,9 @@ struct DiffusionParams { const sd::Tensor* x = nullptr; + const sd::Tensor* audio_x = nullptr; const sd::Tensor* timesteps = nullptr; + const sd::Tensor* audio_timesteps = nullptr; const sd::Tensor* context = nullptr; const sd::Tensor* c_concat = nullptr; const sd::Tensor* y = nullptr; @@ -28,6 +31,7 @@ struct DiffusionParams { float control_strength = 0.f; const sd::Tensor* vace_context = nullptr; float vace_strength = 1.f; + int audio_length = 0; const std::vector* skip_layers = nullptr; }; @@ -579,4 +583,69 @@ struct ErnieImageModel : public DiffusionModel { } }; +struct LTXAVModel : public DiffusionModel { + std::string prefix; + LTXV::LTXAVRunner ltxav; + + LTXAVModel(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model") + : prefix(prefix), ltxav(backend, offload_params_to_cpu, tensor_storage_map, prefix) { + } + + std::string get_desc() override { + return ltxav.get_desc(); + } + + void alloc_params_buffer() override { + ltxav.alloc_params_buffer(); + } + + void free_params_buffer() override { + ltxav.free_params_buffer(); + } + + void free_compute_buffer() override { + ltxav.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) override { + ltxav.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() override { + return ltxav.get_params_buffer_size(); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + ltxav.set_weight_adapter(adapter); + } + + int64_t get_adm_in_channels() override { + return 0; + } + + void set_flash_attention_enabled(bool 
enabled) override { + ltxav.set_flash_attention_enabled(enabled); + } + + void set_circular_axes(bool circular_x, bool circular_y) override { + ltxav.set_circular_axes(circular_x, circular_y); + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + return ltxav.compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context), + tensor_or_empty(diffusion_params.audio_x), + tensor_or_empty(diffusion_params.audio_timesteps), + diffusion_params.audio_length); + } +}; + #endif diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 859270cb..1053d24d 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -1675,13 +1675,22 @@ struct WeightAdapter { }; struct GGMLRunnerContext { - ggml_backend_t backend = nullptr; - ggml_context* ggml_ctx = nullptr; - bool flash_attn_enabled = false; - bool conv2d_direct_enabled = false; - bool circular_x_enabled = false; - bool circular_y_enabled = false; - std::shared_ptr weight_adapter = nullptr; + ggml_backend_t backend = nullptr; + ggml_context* ggml_ctx = nullptr; + bool flash_attn_enabled = false; + bool conv2d_direct_enabled = false; + bool circular_x_enabled = false; + bool circular_y_enabled = false; + std::shared_ptr weight_adapter = nullptr; + std::unordered_map* debug_tensors = nullptr; + + void capture_tensor(const std::string& name, ggml_tensor* tensor) { + if (debug_tensors == nullptr || tensor == nullptr) { + return; + } + ggml_set_output(tensor); + (*debug_tensors)[tensor] = name; + } }; struct GGMLRunner { @@ -1713,6 +1722,7 @@ protected: std::map backend_tensor_data_map; std::map cache_tensor_map; // name -> tensor + std::unordered_map debug_tensors; const std::string final_result_name = "ggml_runner_final_result_tensor"; bool flash_attn_enabled = false; @@ -1799,6 +1809,7 @@ protected: } void free_compute_ctx() { + 
debug_tensors.clear(); if (compute_ctx != nullptr) { ggml_free(compute_ctx); compute_ctx = nullptr; @@ -1834,6 +1845,11 @@ protected: auto result = ggml_graph_node(gf, -1); ggml_set_name(result, final_result_name.c_str()); } + for (const auto& entry : debug_tensors) { + if (entry.first != nullptr) { + ggml_build_forward_expand(gf, entry.first); + } + } prepare_build_in_tensor_after(gf); return gf; } @@ -1903,6 +1919,21 @@ protected: for (auto& kv : backend_tensor_data_map) { auto tensor = kv.first; auto data = kv.second; + if (tensor == nullptr || data == nullptr) { + continue; + } + const char* name = ggml_get_name(tensor); + if (tensor->buffer == nullptr) { + LOG_WARN("%s skip backend tensor copy: tensor buffer not set, name='%s', ne=[%lld,%lld,%lld,%lld], type=%s", + get_desc().c_str(), + name != nullptr ? name : "", + (long long)tensor->ne[0], + (long long)tensor->ne[1], + (long long)tensor->ne[2], + (long long)tensor->ne[3], + ggml_type_name(tensor->type)); + continue; + } ggml_backend_tensor_set(tensor, data, 0, ggml_nbytes(tensor)); } @@ -2025,6 +2056,7 @@ public: runner_ctx.circular_x_enabled = circular_x_enabled; runner_ctx.circular_y_enabled = circular_y_enabled; runner_ctx.weight_adapter = weight_adapter; + runner_ctx.debug_tensors = &debug_tensors; return runner_ctx; } @@ -2163,6 +2195,21 @@ public: LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status)); return std::nullopt; } + for (const auto& entry : debug_tensors) { + auto tensor = entry.first; + if (tensor == nullptr) { + continue; + } + if (tensor->type != GGML_TYPE_F32) { + LOG_WARN("%s skip debug tensor '%s': only GGML_TYPE_F32 is supported, got %s", + get_desc().c_str(), + entry.second.c_str(), + ggml_type_name(tensor->type)); + continue; + } + auto debug_tensor = sd::make_sd_tensor_from_ggml(tensor); + print_sd_tensor(debug_tensor, false, entry.second.c_str()); + } copy_cache_tensors_to_cache_buffer(); auto result = ggml_get_tensor(compute_ctx, 
final_result_name.c_str()); std::optional> output; diff --git a/src/llm.hpp b/src/llm.hpp index 95030385..2f73702f 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -2,8 +2,10 @@ #define __LLM_HPP__ #include +#include #include #include +#include #include #include #include @@ -30,6 +32,7 @@ namespace LLM { QWEN3, MISTRAL_SMALL_3_2, MINISTRAL_3_3B, + GEMMA3_12B, ARCH_COUNT, }; @@ -38,6 +41,12 @@ namespace LLM { "qwen3", "mistral_small3.2", "ministral3.3b", + "gemma3_12b", + }; + + enum class MLPActivation { + SILU, + GELU_TANH, }; struct LLMVisionParams { @@ -55,23 +64,71 @@ namespace LLM { }; struct LLMParams { - LLMArch arch = LLMArch::QWEN2_5_VL; - int64_t num_layers = 28; - int64_t hidden_size = 3584; - int64_t intermediate_size = 18944; - int num_heads = 28; - int num_kv_heads = 4; - int head_dim = 128; - bool qkv_bias = true; - bool qk_norm = false; - int64_t vocab_size = 152064; - float rms_norm_eps = 1e-06f; + LLMArch arch = LLMArch::QWEN2_5_VL; + int64_t num_layers = 28; + int64_t hidden_size = 3584; + int64_t intermediate_size = 18944; + int num_heads = 28; + int num_kv_heads = 4; + int head_dim = 128; + bool qkv_bias = true; + bool qk_norm = false; + bool rms_norm_add = false; + bool normalize_input = false; + int64_t vocab_size = 152064; + int64_t max_position_embeddings = 128000; + float rms_norm_eps = 1e-06f; + MLPActivation mlp_activation = MLPActivation::SILU; + std::vector rope_thetas = {1000000.f}; + std::vector rope_scales = {1.f}; + std::vector sliding_attention; LLMVisionParams vision; }; - struct MLP : public GGMLBlock { + struct LLMRMSNorm : public UnaryBlock { + protected: + int64_t hidden_size; + float eps; + bool add_unit_offset; + std::string prefix; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + std::string prefix = "") override { + this->prefix = prefix; + params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + public: - MLP(int64_t hidden_size, int64_t intermediate_size, 
bool bias = false) { + LLMRMSNorm(int64_t hidden_size, + float eps = 1e-06f, + bool add_unit_offset = false) + : hidden_size(hidden_size), eps(eps), add_unit_offset(add_unit_offset) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { + ggml_tensor* w = params["weight"]; + if (ctx->weight_adapter) { + w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight"); + } + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + auto scaled = ggml_mul(ctx->ggml_ctx, x, w); + if (add_unit_offset) { + scaled = ggml_add_inplace(ctx->ggml_ctx, scaled, x); + } + return scaled; + } + }; + + struct MLP : public GGMLBlock { + protected: + MLPActivation activation; + + public: + MLP(int64_t hidden_size, + int64_t intermediate_size, + bool bias = false, + MLPActivation activation_ = MLPActivation::SILU) + : activation(activation_) { blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias)); @@ -84,9 +141,13 @@ namespace LLM { auto down_proj = std::dynamic_pointer_cast(blocks["down_proj"]); auto h = gate_proj->forward(ctx, x); - h = ggml_silu_inplace(ctx->ggml_ctx, h); - h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); - h = down_proj->forward(ctx, h); + if (activation == MLPActivation::GELU_TANH) { + h = ggml_ext_gelu(ctx->ggml_ctx, h, true); + } else { + h = ggml_silu_inplace(ctx->ggml_ctx, h); + } + h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); + h = down_proj->forward(ctx, h); return h; } }; @@ -377,24 +438,35 @@ namespace LLM { int64_t num_heads; int64_t num_kv_heads; bool qk_norm; + int64_t max_position_embeddings; + std::vector rope_thetas; + std::vector rope_scales; public: Attention(const LLMParams& params) - : arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), 
head_dim(params.head_dim), qk_norm(params.qk_norm) { + : arch(params.arch), + num_heads(params.num_heads), + num_kv_heads(params.num_kv_heads), + head_dim(params.head_dim), + qk_norm(params.qk_norm), + max_position_embeddings(params.max_position_embeddings), + rope_thetas(params.rope_thetas), + rope_scales(params.rope_scales) { blocks["q_proj"] = std::make_shared(params.hidden_size, num_heads * head_dim, params.qkv_bias); blocks["k_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); blocks["v_proj"] = std::make_shared(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); blocks["o_proj"] = std::make_shared(num_heads * head_dim, params.hidden_size, false); if (params.qk_norm) { - blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps); - blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps); + blocks["q_norm"] = std::make_shared(head_dim, params.rms_norm_eps, params.rms_norm_add); + blocks["k_norm"] = std::make_shared(head_dim, params.rms_norm_eps, params.rms_norm_add); } } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* input_pos, - ggml_tensor* attention_mask = nullptr) { + ggml_tensor* attention_mask = nullptr, + int rope_index = 0) { // x: [N, n_token, hidden_size] int64_t n_token = x->ne[1]; int64_t N = x->ne[2]; @@ -412,8 +484,8 @@ namespace LLM { v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] if (qk_norm) { - auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); - auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); q = q_norm->forward(ctx, q); k = k_norm->forward(ctx, k); @@ -428,6 +500,36 @@ namespace LLM { } else if (arch == LLMArch::QWEN3) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 
1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 40960, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + } else if (arch == LLMArch::GEMMA3_12B) { + float rope_theta = (rope_index == 1 ? 10000.0f : 1000000.0f); + float rope_scale = (rope_index == 1 ? 1.f : 8.f); + float freq_scale = 1.f / rope_scale; + q = ggml_rope_ext(ctx->ggml_ctx, + q, + input_pos, + nullptr, + head_dim, + GGML_ROPE_TYPE_NORMAL, + 0, + rope_theta, + freq_scale, + 0.f, + 1.f, + 32.f, + 1.f); + k = ggml_rope_ext(ctx->ggml_ctx, + k, + input_pos, + nullptr, + head_dim, + GGML_ROPE_TYPE_NORMAL, + 0, + rope_theta, + freq_scale, + 0.f, + 1.f, + 32.f, + 1.f); } else { int sections[4] = {16, 24, 24, 0}; q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); @@ -448,33 +550,76 @@ namespace LLM { }; struct TransformerBlock : public GGMLBlock { + protected: + LLMArch arch; + int sliding_attention; + bool has_post_attention_norm; + bool has_post_ffw_norm; + public: - TransformerBlock(const LLMParams& params) { + TransformerBlock(const LLMParams& params, int layer_index) + : arch(params.arch), + sliding_attention(0), + has_post_attention_norm(params.arch == LLMArch::GEMMA3_12B), + has_post_ffw_norm(params.arch == LLMArch::GEMMA3_12B) { blocks["self_attn"] = std::make_shared(params); - blocks["mlp"] = std::make_shared(params.hidden_size, params.intermediate_size); - blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); - blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps); + blocks["mlp"] = std::make_shared(params.hidden_size, + params.intermediate_size, + false, + params.mlp_activation); + blocks["input_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + blocks["post_attention_layernorm"] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + 
if (has_post_attention_norm) { + blocks["post_attention_norm"] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + } + if (has_post_ffw_norm) { + blocks["post_ffw_norm"] = std::make_shared(params.hidden_size, params.rms_norm_eps, params.rms_norm_add); + } + if (!params.sliding_attention.empty()) { + sliding_attention = params.sliding_attention[layer_index % params.sliding_attention.size()]; + } } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* input_pos, - ggml_tensor* attention_mask = nullptr) { + ggml_tensor* attention_mask = nullptr, + ggml_tensor* sliding_attention_mask = nullptr) { // x: [N, n_token, hidden_size] - auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); - auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); - auto input_layernorm = std::dynamic_pointer_cast(blocks["input_layernorm"]); - auto post_attention_layernorm = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + auto input_layernorm = std::dynamic_pointer_cast(blocks["input_layernorm"]); + auto post_attention_layernorm = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + std::shared_ptr post_attention_norm = nullptr; + std::shared_ptr post_ffw_norm = nullptr; + if (has_post_attention_norm) { + post_attention_norm = std::dynamic_pointer_cast(blocks["post_attention_norm"]); + } + if (has_post_ffw_norm) { + post_ffw_norm = std::dynamic_pointer_cast(blocks["post_ffw_norm"]); + } + ggml_tensor* block_attention_mask = attention_mask; + int rope_index = 0; + if (arch == LLMArch::GEMMA3_12B && sliding_attention > 0) { + block_attention_mask = sliding_attention_mask; + rope_index = 1; + } auto residual = x; x = input_layernorm->forward(ctx, x); - x = self_attn->forward(ctx, x, input_pos, attention_mask); - x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + x = 
self_attn->forward(ctx, x, input_pos, block_attention_mask, rope_index); + if (post_attention_norm != nullptr) { + x = post_attention_norm->forward(ctx, x); + } + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); residual = x; x = post_attention_layernorm->forward(ctx, x); x = mlp->forward(ctx, x); - x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + if (post_ffw_norm != nullptr) { + x = post_ffw_norm->forward(ctx, x); + } + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); return x; } @@ -483,28 +628,36 @@ namespace LLM { struct TextModel : public GGMLBlock { protected: int64_t num_layers; + int64_t hidden_size; + bool normalize_input; + float input_scale; public: TextModel(const LLMParams& params) - : num_layers(params.num_layers) { + : num_layers(params.num_layers), + hidden_size(params.hidden_size), + normalize_input(params.normalize_input), + input_scale(std::sqrt(static_cast(params.hidden_size))) { blocks["embed_tokens"] = std::shared_ptr(new Embedding(params.vocab_size, params.hidden_size)); for (int i = 0; i < num_layers; i++) { - blocks["layers." + std::to_string(i)] = std::shared_ptr(new TransformerBlock(params)); + blocks["layers." 
+ std::to_string(i)] = std::shared_ptr(new TransformerBlock(params, i)); } - blocks["norm"] = std::shared_ptr(new RMSNorm(params.hidden_size, params.rms_norm_eps)); + blocks["norm"] = std::shared_ptr(new LLMRMSNorm(params.hidden_size, params.rms_norm_eps, params.rms_norm_add)); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* input_ids, ggml_tensor* input_pos, ggml_tensor* attention_mask, + ggml_tensor* sliding_attention_mask, std::vector> image_embeds, - std::set out_layers) { + std::set out_layers, + bool return_all_hidden_states = false) { // input_ids: [N, n_token] // return: [N, n_token, hidden_size] auto embed_tokens = std::dynamic_pointer_cast(blocks["embed_tokens"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto x = embed_tokens->forward(ctx, input_ids); @@ -549,22 +702,44 @@ namespace LLM { x = input_embed; } + if (normalize_input) { + x = ggml_ext_scale(ctx->ggml_ctx, x, input_scale, true); + } + + if (return_all_hidden_states) { + intermediate_outputs.push_back(x); + } + for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." 
+ std::to_string(i)]); - x = block->forward(ctx, x, input_pos, attention_mask); - if (out_layers.find(i + 1) != out_layers.end()) { + x = block->forward(ctx, x, input_pos, attention_mask, sliding_attention_mask); + if (return_all_hidden_states) { + if (i + 1 < num_layers) { + intermediate_outputs.push_back(x); + } + } else if (out_layers.find(i + 1) != out_layers.end()) { intermediate_outputs.push_back(x); } } - if (!intermediate_outputs.empty()) { + auto normed_x = norm->forward(ctx, x); + if (return_all_hidden_states) { + intermediate_outputs.push_back(normed_x); + x = intermediate_outputs[0]; + for (int i = 1; i < intermediate_outputs.size(); i++) { + x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0); + } + } else if (!intermediate_outputs.empty()) { + if (out_layers.find(static_cast(num_layers + 1)) != out_layers.end()) { + intermediate_outputs.push_back(normed_x); + } x = intermediate_outputs[0]; for (int i = 1; i < intermediate_outputs.size(); i++) { x = ggml_concat(ctx->ggml_ctx, x, intermediate_outputs[i], 0); } } else { - x = norm->forward(ctx, x); + x = normed_x; } return x; } @@ -599,12 +774,21 @@ namespace LLM { ggml_tensor* input_ids, ggml_tensor* input_pos, ggml_tensor* attention_mask, + ggml_tensor* sliding_attention_mask, std::vector> image_embeds, - std::set out_layers) { + std::set out_layers, + bool return_all_hidden_states = false) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); + auto x = model->forward(ctx, + input_ids, + input_pos, + attention_mask, + sliding_attention_mask, + image_embeds, + out_layers, + return_all_hidden_states); return x; } @@ -627,6 +811,7 @@ namespace LLM { std::vector input_pos_vec; std::vector attention_mask_vec; + std::vector sliding_attention_mask_vec; std::vector window_mask_vec; std::vector window_index_vec; std::vector window_inverse_index_vec; @@ -653,6 +838,23 @@ 
namespace LLM { params.qkv_bias = false; params.qk_norm = true; params.rms_norm_eps = 1e-6f; + } else if (arch == LLMArch::GEMMA3_12B) { + params.head_dim = 256; + params.num_heads = 16; + params.num_kv_heads = 8; + params.qkv_bias = false; + params.qk_norm = true; + params.rms_norm_eps = 1e-6f; + // llama.cpp adds +1 to Gemma3 norm.weight when exporting GGUF, so GGUF loading + // must keep rms_norm_add disabled here or the offset gets applied twice. + // Convenient for the converter, less convenient for whoever gets to debug it later. + params.rms_norm_add = false; + params.normalize_input = true; + params.max_position_embeddings = 131072; + params.mlp_activation = MLPActivation::GELU_TANH; + params.rope_thetas = {1000000.f, 10000.f}; + params.rope_scales = {8.f, 1.f}; + params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0}; } bool have_vision_weight = false; bool llama_cpp_style = false; @@ -722,9 +924,18 @@ namespace LLM { ggml_tensor* input_ids, ggml_tensor* input_pos, ggml_tensor* attention_mask, + ggml_tensor* sliding_attention_mask, std::vector> image_embeds, - std::set out_layers) { - auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size] + std::set out_layers, + bool return_all_hidden_states = false) { + auto hidden_states = model.forward(ctx, + input_ids, + input_pos, + attention_mask, + sliding_attention_mask, + image_embeds, + out_layers, + return_all_hidden_states); // [N, n_token, hidden_size] return hidden_states; } @@ -741,8 +952,9 @@ namespace LLM { ggml_cgraph* build_graph(const sd::Tensor& input_ids_tensor, const sd::Tensor& attention_mask_tensor, const std::vector>>& image_embeds_tensor, - std::set out_layers) { - ggml_cgraph* gf = ggml_new_graph(compute_ctx); + std::set out_layers, + bool return_all_hidden_states = false) { + ggml_cgraph* gf = new_graph_custom(LLM_GRAPH_SIZE); ggml_tensor* input_ids = make_input(input_ids_tensor); std::vector> image_embeds; 
image_embeds.reserve(image_embeds_tensor.size());
@@ -752,7 +964,10 @@ namespace LLM {
             }
 
             int64_t n_tokens = input_ids->ne[0];
-            if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::MINISTRAL_3_3B || params.arch == LLMArch::QWEN3) {
+            if (params.arch == LLMArch::MISTRAL_SMALL_3_2 ||
+                params.arch == LLMArch::MINISTRAL_3_3B ||
+                params.arch == LLMArch::QWEN3 ||
+                params.arch == LLMArch::GEMMA3_12B) {
                 input_pos_vec.resize(n_tokens);
                 for (int i = 0; i < n_tokens; ++i) {
                     input_pos_vec[i] = i;
@@ -772,7 +987,8 @@ namespace LLM {
                                                        input_pos_vec.size());
             set_backend_tensor_data(input_pos, input_pos_vec.data());
 
-            ggml_tensor* attention_mask = nullptr;
+            ggml_tensor* attention_mask         = nullptr;
+            ggml_tensor* sliding_attention_mask = nullptr;
             if (!attention_mask_tensor.empty()) {
                 attention_mask = make_input(attention_mask_tensor);
             } else {
@@ -790,9 +1006,36 @@ namespace LLM {
                 set_backend_tensor_data(attention_mask, attention_mask_vec.data());
             }
 
+            if (params.arch == LLMArch::GEMMA3_12B) {
+                sliding_attention_mask_vec.resize(n_tokens * n_tokens);
+                if (!attention_mask_tensor.empty()) {
+                    GGML_ASSERT(attention_mask_tensor.numel() == n_tokens * n_tokens);
+                    sliding_attention_mask_vec = attention_mask_tensor.values();
+                } else {
+                    sliding_attention_mask_vec = attention_mask_vec;
+                }
+                for (int i0 = 0; i0 < n_tokens; i0++) {
+                    for (int i1 = 0; i1 < n_tokens; i1++) {
+                        if (i0 + 1024 <= i1) {
+                            // key i0 is outside the 1024-token sliding window of query i1
+                            sliding_attention_mask_vec[i1 * n_tokens + i0] = -INFINITY;
+                        }
+                    }
+                }
+                sliding_attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+                set_backend_tensor_data(sliding_attention_mask, sliding_attention_mask_vec.data());
+            }
+
             auto runner_ctx = get_context();
 
-            ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
+            ggml_tensor* hidden_states = forward(&runner_ctx,
+                                                 input_ids,
+                                                 input_pos,
+                                                 attention_mask,
+                                                 sliding_attention_mask,
+                                                 image_embeds,
+                                                 
out_layers, + return_all_hidden_states); ggml_build_forward_expand(gf, hidden_states); @@ -803,9 +1046,14 @@ namespace LLM { const sd::Tensor& input_ids, const sd::Tensor& attention_mask, const std::vector>>& image_embeds, - std::set out_layers) { + std::set out_layers, + bool return_all_hidden_states = false) { auto get_graph = [&]() -> ggml_cgraph* { - return build_graph(input_ids, attention_mask, image_embeds, out_layers); + return build_graph(input_ids, + attention_mask, + image_embeds, + out_layers, + return_all_hidden_states); }; return take_or_empty(GGMLRunner::compute(get_graph, n_threads, true)); } diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp new file mode 100644 index 00000000..baefe4e1 --- /dev/null +++ b/src/ltx_vae.hpp @@ -0,0 +1,971 @@ +#ifndef __SD_LTX_VAE_HPP__ +#define __SD_LTX_VAE_HPP__ + +#include +#include +#include +#include +#include +#include + +#include "ltxv.hpp" +#include "vae.hpp" +#include "wan.hpp" + +namespace LTXVAE { + + static inline ggml_tensor* apply_scale_shift(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* scale, + ggml_tensor* shift) { + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + x = ggml_add(ctx, x, shift); + return x; + } + + static inline ggml_tensor* reshape_channel_broadcast(ggml_context* ctx, + ggml_tensor* x) { + return ggml_reshape_4d(ctx, x, 1, 1, 1, ggml_nelements(x)); + } + + static inline std::pair get_shift_scale(ggml_context* ctx, + ggml_tensor* table, + ggml_tensor* timestep, + int64_t channels, + int parts) { + GGML_ASSERT(timestep != nullptr); + GGML_ASSERT(ggml_nelements(timestep) == channels * parts); + + auto timestep_view = ggml_reshape_2d(ctx, timestep, channels, parts); + auto values = ggml_add(ctx, table, timestep_view); + auto chunks = ggml_ext_chunk(ctx, values, parts, 1, false); + auto shift = reshape_channel_broadcast(ctx, ggml_cont(ctx, chunks[0])); + auto scale = reshape_channel_broadcast(ctx, ggml_cont(ctx, chunks[1])); + return {shift, scale}; + } + + static inline ggml_tensor* 
depth_to_space_3d(ggml_context* ctx, + ggml_tensor* x, + int64_t c, + int factor_t, + int factor_s, + bool drop_first_temporal_frame) { + // x: [B*c*p1*p2*p3, T, H, W], B == 1, p2 == p3 == factor_s, p1 == factor_t + // return: [B*c, T*p1, H*p2, W*p2] + // Match: rearrange(x, "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)") + const int64_t T = x->ne[2]; + const int64_t H = x->ne[1]; + const int64_t W = x->ne[0]; + + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // [T, C, H, W] + x = ggml_reshape_4d(ctx, x, W, H, factor_s, factor_s * factor_t * c * T); // [T*c*p1*p2, p3, H, W] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [T*c*p1*p2, H, W, p3] + x = ggml_reshape_4d(ctx, x, factor_s * W, H, factor_s, factor_t * c * T); // [T*c*p1, p2, H, W*p3] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [T*c*p1, H, p2, W*p3] + x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, factor_t, c, T); // [T, c, p1, H*p2*W*p3] + x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // [c, T, p1, H*p2*W*p3] + x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, c); // [T, c, T*p1, H*p2*W*p3] + + if (drop_first_temporal_frame && factor_t > 1 && x->ne[2] > 0) { + x = ggml_ext_slice(ctx, x, 2, 1, x->ne[2]); + } + + return x; + } + + static inline ggml_tensor* patchify(ggml_context* ctx, + ggml_tensor* x, + int patch_size) { + return WAN::WanVAE::patchify(ctx, x, patch_size, 1); + } + + class CausalConv3d : public GGMLBlock { + protected: + int time_kernel_size; + + public: + CausalConv3d(int64_t in_channels, + int64_t out_channels, + int kernel_size = 3, + std::tuple stride = {1, 1, 1}, + int dilation = 1, + bool bias = true) { + time_kernel_size = kernel_size; + blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, + out_channels, + {kernel_size, kernel_size, kernel_size}, + stride, + {0, kernel_size / 2, kernel_size / 2}, + {dilation, 1, 1}, + bias)); + } + + ggml_tensor* 
forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + bool causal = true) { + // x: [B*C, T, H, W], B == 1 + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + + if (causal) { + auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); + auto first_frame_pad = first_frame; + for (int i = 1; i < time_kernel_size - 1; i++) { + first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2); + } + x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2); + } else { + auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); + auto first_frame_pad = first_frame; + for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { + first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2); + } + + auto last_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, x->ne[2] - 1, x->ne[2]); + auto last_frame_pad = last_frame; + for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { + last_frame_pad = ggml_concat(ctx->ggml_ctx, last_frame_pad, last_frame, 2); + } + x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2); + x = ggml_concat(ctx->ggml_ctx, x, last_frame_pad, 2); + } + return conv->forward(ctx, x); + } + }; + + struct PixelNorm3D : public UnaryBlock { + float eps; + + PixelNorm3D(float eps = 1e-8f) + : eps(eps) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); + h = ggml_rms_norm(ctx->ggml_ctx, h, eps); + h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0)); + return h; + } + }; + + struct PixArtAlphaCombinedTimestepSizeEmbeddings : public GGMLBlock { + int64_t embedding_dim; + + PixArtAlphaCombinedTimestepSizeEmbeddings(int64_t embedding_dim) + : embedding_dim(embedding_dim) { + blocks["timestep_embedder"] = std::make_shared(embedding_dim); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) { + auto timestep_embedder = 
std::dynamic_pointer_cast(blocks["timestep_embedder"]); + return timestep_embedder->forward(ctx, timestep); + } + }; + + struct ResnetBlock3D : public GGMLBlock { + int64_t channels; + bool timestep_conditioning; + + protected: + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + if (timestep_conditioning) { + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, channels, 4); + } + } + + public: + ResnetBlock3D(int64_t channels, + float eps = 1e-6f, + bool timestep_conditioning = false) + : channels(channels), timestep_conditioning(timestep_conditioning) { + blocks["norm1"] = std::make_shared(eps); + blocks["conv1"] = std::make_shared(channels, channels, 3); + blocks["norm2"] = std::make_shared(eps); + blocks["conv2"] = std::make_shared(channels, channels, 3); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep = nullptr, + bool causal = false) { + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto conv1 = std::dynamic_pointer_cast(blocks["conv1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); + + ggml_tensor* shift1 = nullptr; + ggml_tensor* scale1 = nullptr; + ggml_tensor* shift2 = nullptr; + ggml_tensor* scale2 = nullptr; + if (timestep_conditioning) { + GGML_ASSERT(timestep != nullptr); + auto values = ggml_add(ctx->ggml_ctx, + params["scale_shift_table"], + ggml_reshape_2d(ctx->ggml_ctx, timestep, channels, 4)); + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, values, 4, 1, false); + shift1 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[0])); + scale1 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[1])); + shift2 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[2])); + scale2 = reshape_channel_broadcast(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, chunks[3])); + 
} + + auto h = norm1->forward(ctx, x); + if (timestep_conditioning) { + h = apply_scale_shift(ctx->ggml_ctx, h, scale1, shift1); + } + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = conv1->forward(ctx, h, causal); + + h = norm2->forward(ctx, h); + if (timestep_conditioning) { + h = apply_scale_shift(ctx->ggml_ctx, h, scale2, shift2); + } + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = conv2->forward(ctx, h, causal); + + return ggml_add(ctx->ggml_ctx, h, x); + } + }; + + struct UNetMidBlock3D : public GGMLBlock { + int64_t channels; + int num_layers; + bool timestep_conditioning; + + UNetMidBlock3D(int64_t channels, + int num_layers, + bool timestep_conditioning) + : channels(channels), + num_layers(num_layers), + timestep_conditioning(timestep_conditioning) { + if (timestep_conditioning) { + blocks["time_embedder"] = std::make_shared(channels * 4); + } + for (int i = 0; i < num_layers; i++) { + blocks["res_blocks." + std::to_string(i)] = std::make_shared(channels, 1e-6f, timestep_conditioning); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep = nullptr, + bool causal = false) { + ggml_tensor* timestep_embed = nullptr; + if (timestep_conditioning) { + GGML_ASSERT(timestep != nullptr); + auto time_embedder = std::dynamic_pointer_cast(blocks["time_embedder"]); + timestep_embed = time_embedder->forward(ctx, timestep); + } + + for (int i = 0; i < num_layers; i++) { + auto resnet = std::dynamic_pointer_cast(blocks["res_blocks." 
+ std::to_string(i)]); + x = resnet->forward(ctx, x, timestep_embed, causal); + } + return x; + } + }; + + struct DepthToSpaceUpsample : public GGMLBlock { + int64_t in_channels; + int factor_t; + int factor_s; + int out_channels_reduction_factor; + bool residual; + + DepthToSpaceUpsample(int64_t in_channels, + int factor_t = 2, + int factor_s = 2, + int out_channels_reduction_factor = 2, + bool residual = true) + : in_channels(in_channels), + factor_t(factor_t), + factor_s(factor_s), + out_channels_reduction_factor(out_channels_reduction_factor), + residual(residual) { + const int64_t factor = static_cast(factor_t) * static_cast(factor_s) * static_cast(factor_s); + const int64_t out_dim = (factor * in_channels) / out_channels_reduction_factor; + blocks["conv"] = std::make_shared(in_channels, out_dim, 3); + } + + int64_t get_output_channels() const { + return in_channels / out_channels_reduction_factor; + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + bool causal = false) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + + ggml_tensor* x_in = nullptr; + if (residual) { + x_in = depth_to_space_3d(ctx->ggml_ctx, x, in_channels / (factor_t * factor_s * factor_s), factor_t, factor_s, factor_t > 1); + int repeat = (factor_t * factor_s * factor_s) / out_channels_reduction_factor; + auto res = x_in; + for (int i = 1; i < repeat; i++) { + res = ggml_concat(ctx->ggml_ctx, res, x_in, 3); + } + x_in = res; + } + + x = conv->forward(ctx, x, causal); + x = depth_to_space_3d(ctx->ggml_ctx, x, get_output_channels(), factor_t, factor_s, factor_t > 1); + if (residual) { + x = ggml_add(ctx->ggml_ctx, x, x_in); + } + return x; + } + }; + + struct SpaceToDepthDownsample : public GGMLBlock { + int64_t in_channels; + int64_t out_channels; + int factor_t; + int factor_s; + + SpaceToDepthDownsample(int64_t in_channels, + int64_t out_channels, + int factor_t, + int factor_s) + : in_channels(in_channels), + out_channels(out_channels), + factor_t(factor_t), 
+ factor_s(factor_s) { + const int64_t factor = static_cast(factor_t) * static_cast(factor_s) * static_cast(factor_s); + GGML_ASSERT(out_channels % factor == 0); + + blocks["conv"] = std::make_shared(in_channels, out_channels / factor, 3); + blocks["skip_downsample"] = std::make_shared(in_channels, out_channels, factor_t, factor_s); + blocks["conv_downsample"] = std::make_shared(out_channels / factor, out_channels, factor_t, factor_s); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + bool causal = true) { + auto conv = std::dynamic_pointer_cast(blocks["conv"]); + auto skip_downsample = std::dynamic_pointer_cast(blocks["skip_downsample"]); + auto conv_downsample = std::dynamic_pointer_cast(blocks["conv_downsample"]); + + if (factor_t > 1 && x->ne[2] > 0) { + auto first_frame = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); + auto first_frame_pad = first_frame; + for (int i = 1; i < factor_t; ++i) { + first_frame_pad = ggml_concat(ctx->ggml_ctx, first_frame_pad, first_frame, 2); + } + x = ggml_concat(ctx->ggml_ctx, first_frame_pad, x, 2); + } + + auto residual = skip_downsample->forward(ctx, x); + auto h = conv->forward(ctx, x, causal); + h = conv_downsample->forward(ctx, h); + return ggml_add(ctx->ggml_ctx, h, residual); + } + }; + + struct PerChannelStatistics : public GGMLBlock { + protected: + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + params["std-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128); + params["mean-of-means"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128); + } + + public: + ggml_tensor* un_normalize(GGMLRunnerContext* ctx, + ggml_tensor* x) { + auto std_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["std-of-means"]); + auto mean_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["mean-of-means"]); + return ggml_add(ctx->ggml_ctx, ggml_mul(ctx->ggml_ctx, x, std_tensor), mean_tensor); + } + + ggml_tensor* 
normalize(GGMLRunnerContext* ctx, + ggml_tensor* x) { + auto std_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["std-of-means"]); + auto mean_tensor = reshape_channel_broadcast(ctx->ggml_ctx, params["mean-of-means"]); + return ggml_div(ctx->ggml_ctx, ggml_sub(ctx->ggml_ctx, x, mean_tensor), std_tensor); + } + }; + + struct DecoderConfig { + struct Block { + std::string type; + int num_layers = 0; + int multiplier = 1; + }; + + std::vector blocks; + }; + + struct EncoderConfig { + struct Block { + std::string type; + int num_layers = 0; + int multiplier = 1; + }; + + std::vector blocks; + }; + + static inline bool has_tensor(const String2TensorStorage& tensor_storage_map, + const std::string& name) { + return tensor_storage_map.find(name) != tensor_storage_map.end(); + } + + static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map, + const std::string& name, + int64_t fallback = 0) { + auto iter = tensor_storage_map.find(name); + if (iter == tensor_storage_map.end()) { + return fallback; + } + return iter->second.ne[0]; + } + + static inline DecoderConfig infer_decoder_config_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + int64_t conv_in_channels) { + DecoderConfig cfg; + const std::string decoder_prefix = prefix + ".decoder.up_blocks."; + + int64_t current_channels = conv_in_channels; + for (int block_idx = 0;; ++block_idx) { + const std::string block_prefix = decoder_prefix + std::to_string(block_idx); + const std::string res0_bias = block_prefix + ".res_blocks.0.conv1.conv.bias"; + const std::string conv_bias = block_prefix + ".conv.conv.bias"; + + if (has_tensor(tensor_storage_map, res0_bias)) { + int num_layers = 0; + while (has_tensor(tensor_storage_map, + block_prefix + ".res_blocks." 
+ std::to_string(num_layers) + ".conv1.conv.bias")) { + num_layers++; + } + cfg.blocks.push_back({"res_x", num_layers, 1}); + current_channels = get_tensor_ne0(tensor_storage_map, res0_bias, current_channels); + continue; + } + + if (!has_tensor(tensor_storage_map, conv_bias)) { + break; + } + + int64_t next_channels = 0; + for (int next_idx = block_idx + 1;; ++next_idx) { + const std::string next_res0_bias = decoder_prefix + std::to_string(next_idx) + ".res_blocks.0.conv1.conv.bias"; + const std::string next_conv_bias = decoder_prefix + std::to_string(next_idx) + ".conv.conv.bias"; + if (has_tensor(tensor_storage_map, next_res0_bias)) { + next_channels = get_tensor_ne0(tensor_storage_map, next_res0_bias); + break; + } + if (!has_tensor(tensor_storage_map, next_conv_bias)) { + break; + } + } + if (next_channels <= 0 || current_channels % next_channels != 0) { + next_channels = std::max(1, current_channels / 2); + } + + const int64_t conv_out_dim = get_tensor_ne0(tensor_storage_map, conv_bias); + const int64_t reduction = std::max(1, current_channels / next_channels); + const int64_t factor = next_channels > 0 ? 
conv_out_dim / next_channels : 0; + + if (factor == 8) { + cfg.blocks.push_back({"compress_all", 0, static_cast(reduction)}); + } else if (factor == 4) { + cfg.blocks.push_back({"compress_space", 0, static_cast(reduction)}); + } else if (factor == 2) { + cfg.blocks.push_back({"compress_time", 0, static_cast(reduction)}); + } else { + LOG_WARN("unexpected LTX VAE upsample factor at '%s': conv_out=%lld current=%lld next=%lld, falling back to compress_all x%d", + block_prefix.c_str(), + (long long)conv_out_dim, + (long long)current_channels, + (long long)next_channels, + (int)reduction); + cfg.blocks.push_back({"compress_all", 0, static_cast(reduction)}); + } + current_channels = next_channels; + } + + return cfg; + } + + static inline int detect_ltx_vae_version(const String2TensorStorage& tensor_storage_map, + const std::string& prefix) { + const std::string v2_probe = prefix + ".encoder.down_blocks.1.conv.conv.bias"; + if (tensor_storage_map.find(v2_probe) != tensor_storage_map.end()) { + return 2; + } + return 1; + } + + static inline bool detect_ltx_vae_timestep_conditioning(const String2TensorStorage& tensor_storage_map, + const std::string& prefix) { + return tensor_storage_map.find(prefix + ".decoder.timestep_scale_multiplier") != tensor_storage_map.end(); + } + + static inline EncoderConfig get_encoder_config(int version) { + EncoderConfig cfg; + if (version < 2) { + GGML_ABORT("LTX VAE encoder is only implemented for version >= 2"); + } + + cfg.blocks = { + {"res_x", 4, 1}, + {"compress_space_res", 0, 2}, + {"res_x", 6, 1}, + {"compress_time_res", 0, 2}, + {"res_x", 6, 1}, + {"compress_all_res", 0, 2}, + {"res_x", 2, 1}, + {"compress_all_res", 0, 2}, + {"res_x", 2, 1}, + }; + return cfg; + } + + struct Encoder : public GGMLBlock { + int version; + int patch_size; + int64_t in_channels; + int64_t latent_channels; + + Encoder(int version, + int patch_size = 4, + int64_t in_channels = 3, + int64_t latent_channels = 128) + : version(version), + 
patch_size(patch_size), + in_channels(in_channels), + latent_channels(latent_channels) { + auto cfg = get_encoder_config(version); + int64_t channels = 128; + int64_t in_dim = in_channels * patch_size * patch_size; + + blocks["conv_in"] = std::make_shared(in_dim, channels, 3); + + for (int block_idx = 0; block_idx < static_cast(cfg.blocks.size()); ++block_idx) { + const auto& block = cfg.blocks[block_idx]; + if (block.type == "res_x") { + blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + block.num_layers, + false); + } else if (block.type == "compress_space_res") { + int64_t next_channels = channels * block.multiplier; + blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + next_channels, + 1, + 2); + channels = next_channels; + } else if (block.type == "compress_time_res") { + int64_t next_channels = channels * block.multiplier; + blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + next_channels, + 2, + 1); + channels = next_channels; + } else if (block.type == "compress_all_res") { + int64_t next_channels = channels * block.multiplier; + blocks["down_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + next_channels, + 2, + 2); + channels = next_channels; + } else { + GGML_ABORT("Unsupported LTX VAE encoder block"); + } + } + + blocks["conv_norm_out"] = std::make_shared(); + blocks["conv_out"] = std::make_shared(channels, latent_channels + 1, 3); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x) { + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto conv_norm_out = std::dynamic_pointer_cast(blocks["conv_norm_out"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + + x = conv_in->forward(ctx, x, true); + + int block_idx = 0; + while (blocks.find("down_blocks." + std::to_string(block_idx)) != blocks.end()) { + auto mid_block = std::dynamic_pointer_cast(blocks["down_blocks." 
+ std::to_string(block_idx)]); + if (mid_block) { + x = mid_block->forward(ctx, x, nullptr, true); + } else { + auto downsample = std::dynamic_pointer_cast(blocks["down_blocks." + std::to_string(block_idx)]); + x = downsample->forward(ctx, x, true); + } + block_idx++; + } + + x = conv_norm_out->forward(ctx, x); + x = ggml_silu_inplace(ctx->ggml_ctx, x); + x = conv_out->forward(ctx, x, true); + + auto last_channel = ggml_ext_slice(ctx->ggml_ctx, x, 3, x->ne[3] - 1, x->ne[3]); + auto repeat_shape = ggml_new_tensor_4d(ctx->ggml_ctx, last_channel->type, last_channel->ne[0], last_channel->ne[1], last_channel->ne[2], latent_channels - 1); + auto repeated = ggml_repeat(ctx->ggml_ctx, last_channel, repeat_shape); + return ggml_concat(ctx->ggml_ctx, x, repeated, 3); + } + }; + + struct Decoder : public GGMLBlock { + int version; + int patch_size; + bool causal_decoder; + bool timestep_conditioning; + int64_t in_channels; + int64_t hidden_channels; + + protected: + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + if (timestep_conditioning) { + params["timestep_scale_multiplier"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + params["last_scale_shift_table"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_channels, 2); + } + } + + public: + Decoder(int version, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + int patch_size = 4, + bool causal_decoder = false, + bool timestep_conditioning = true, + int64_t in_channels = 128, + int64_t hidden_channels = 128) + : version(version), + patch_size(patch_size), + causal_decoder(causal_decoder), + timestep_conditioning(timestep_conditioning), + in_channels(in_channels), + hidden_channels(hidden_channels) { + const int64_t conv_in_out_channels = get_tensor_ne0(tensor_storage_map, + prefix + ".decoder.conv_in.conv.bias", + hidden_channels); + auto cfg = infer_decoder_config_from_weights(tensor_storage_map, + prefix, + 
conv_in_out_channels); + int64_t channels = conv_in_out_channels; + + blocks["conv_in"] = std::make_shared(in_channels, channels, 3); + + for (int block_idx = 0; block_idx < static_cast(cfg.blocks.size()); ++block_idx) { + const auto& block = cfg.blocks[block_idx]; + if (block.type == "res_x") { + blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + block.num_layers, + timestep_conditioning); + } else if (block.type == "compress_all") { + blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + 2, + 2, + block.multiplier, + false); + channels /= block.multiplier; + } else if (block.type == "compress_time") { + blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + 2, + 1, + block.multiplier, + false); + channels /= block.multiplier; + } else if (block.type == "compress_space") { + blocks["up_blocks." + std::to_string(block_idx)] = std::make_shared(channels, + 1, + 2, + block.multiplier, + false); + channels /= block.multiplier; + } else { + GGML_ABORT("Unsupported LTX VAE decoder block"); + } + } + + hidden_channels = channels; + blocks["conv_norm_out"] = std::make_shared(); + blocks["conv_out"] = std::make_shared(hidden_channels, 3 * patch_size * patch_size, 3); + if (timestep_conditioning) { + blocks["last_time_embedder"] = std::make_shared(hidden_channels * 2); + } + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep) { + auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); + auto conv_norm_out = std::dynamic_pointer_cast(blocks["conv_norm_out"]); + auto conv_out = std::dynamic_pointer_cast(blocks["conv_out"]); + + ggml_tensor* scaled_timestep = timestep; + if (timestep_conditioning) { + auto multiplier = ggml_ext_backend_tensor_get_f32(params["timestep_scale_multiplier"]); + scaled_timestep = ggml_ext_scale(ctx->ggml_ctx, timestep, multiplier); + } + + x = conv_in->forward(ctx, x, causal_decoder); + + int block_idx = 0; + while 
(blocks.find("up_blocks." + std::to_string(block_idx)) != blocks.end()) { + auto mid_block = std::dynamic_pointer_cast(blocks["up_blocks." + std::to_string(block_idx)]); + if (mid_block) { + x = mid_block->forward(ctx, x, scaled_timestep, causal_decoder); + } else { + auto upsample = std::dynamic_pointer_cast(blocks["up_blocks." + std::to_string(block_idx)]); + x = upsample->forward(ctx, x, causal_decoder); + } + block_idx++; + } + + x = conv_norm_out->forward(ctx, x); + if (timestep_conditioning) { + auto last_time_embedder = std::dynamic_pointer_cast(blocks["last_time_embedder"]); + auto timestep_embed = last_time_embedder->forward(ctx, scaled_timestep); + auto [shift, scale] = get_shift_scale(ctx->ggml_ctx, + params["last_scale_shift_table"], + timestep_embed, + hidden_channels, + 2); + x = apply_scale_shift(ctx->ggml_ctx, x, scale, shift); + } + x = ggml_silu_inplace(ctx->ggml_ctx, x); + x = conv_out->forward(ctx, x, causal_decoder); + return x; + } + }; + + struct VideoVAE : public GGMLBlock { + int version; + float decode_timestep; + bool timestep_conditioning; + int patch_size; + bool decode_only; + + VideoVAE(int version, + bool decode_only, + bool timestep_conditioning, + int patch_size, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + float decode_timestep = 0.05f) + : version(version), + decode_timestep(decode_timestep), + timestep_conditioning(timestep_conditioning), + patch_size(patch_size), + decode_only(decode_only) { + if (!decode_only) { + blocks["encoder"] = std::make_shared(version, patch_size); + } + blocks["decoder"] = std::make_shared(version, + tensor_storage_map, + prefix, + patch_size, + false, + timestep_conditioning); + blocks["per_channel_statistics"] = std::make_shared(); + } + + ggml_tensor* decode(GGMLRunnerContext* ctx, + ggml_tensor* z, + ggml_tensor* timestep) { + auto decoder = std::dynamic_pointer_cast(blocks["decoder"]); + auto processor = 
std::dynamic_pointer_cast(blocks["per_channel_statistics"]); + auto latents = processor->un_normalize(ctx, z); + auto out = decoder->forward(ctx, latents, timestep); + out = WAN::WanVAE::unpatchify(ctx->ggml_ctx, out, patch_size, 1); + return out; + } + + ggml_tensor* encode(GGMLRunnerContext* ctx, + ggml_tensor* x) { + GGML_ASSERT(!decode_only); + auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); + auto processor = std::dynamic_pointer_cast(blocks["per_channel_statistics"]); + + x = patchify(ctx->ggml_ctx, x, patch_size); + auto out = encoder->forward(ctx, x); + auto mean = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3, false)[0]; + mean = ggml_cont(ctx->ggml_ctx, mean); + return processor->normalize(ctx, mean); + } + }; + +} // namespace LTXVAE + +struct LTXVideoVAE : public VAE { + bool decode_only; + int ltx_vae_version; + bool timestep_conditioning; + int patch_size; + sd::Tensor decode_timestep_tensor; + LTXVAE::VideoVAE vae; + + LTXVideoVAE(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + bool decode_only = true, + SDVersion version = VERSION_LTXAV) + : decode_only(decode_only), + ltx_vae_version(LTXVAE::detect_ltx_vae_version(tensor_storage_map, prefix)), + timestep_conditioning(LTXVAE::detect_ltx_vae_timestep_conditioning(tensor_storage_map, prefix)), + patch_size(4), + decode_timestep_tensor(sd::Tensor::from_vector({0.05f})), + vae(LTXVAE::detect_ltx_vae_version(tensor_storage_map, prefix), + decode_only, + LTXVAE::detect_ltx_vae_timestep_conditioning(tensor_storage_map, prefix), + patch_size, + tensor_storage_map, + prefix), + VAE(version, backend, offload_params_to_cpu) { + vae.init(params_ctx, tensor_storage_map, prefix); + decode_timestep_tensor.values()[0] = vae.decode_timestep; + } + + std::string get_desc() override { + return "ltx_video_vae"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) override { + 
vae.get_param_tensors(tensors, prefix); + } + + ggml_cgraph* build_graph(const sd::Tensor& z_tensor, bool decode_graph) { + LOG_DEBUG("ltx_video_vae build_graph input %dx%dx%dx%d", + (int)z_tensor.shape()[0], + (int)z_tensor.shape()[1], + (int)z_tensor.shape()[2], + (int)z_tensor.shape()[3]); + ggml_cgraph* gf = ggml_new_graph(compute_ctx); + ggml_tensor* z = make_input(z_tensor); + ggml_tensor* timestep = nullptr; + if (timestep_conditioning) { + timestep = make_input(decode_timestep_tensor); + } + + auto runner_ctx = get_context(); + ggml_tensor* out = decode_graph ? vae.decode(&runner_ctx, z, timestep) : vae.encode(&runner_ctx, z); + LOG_DEBUG("ltx_video_vae build_graph output ne=[%lld,%lld,%lld,%lld]", + (long long)out->ne[0], + (long long)out->ne[1], + (long long)out->ne[2], + (long long)out->ne[3]); + ggml_build_forward_expand(gf, out); + + return gf; + } + + sd::Tensor _compute(const int n_threads, + const sd::Tensor& z, + bool decode_graph) override { + if (!decode_graph && decode_only) { + LOG_ERROR("LTX video VAE encoder is not implemented yet"); + return {}; + } + sd::Tensor input = z; + size_t expected_dim = static_cast(z.dim()); + if (!decode_graph) { + if (input.dim() == 4) { + input = input.unsqueeze(2); + expected_dim = 5; + } else if (input.dim() != 5) { + LOG_ERROR("LTX video VAE encoder expects 4D image or 5D video input, got dim=%lld", + (long long)input.dim()); + return {}; + } + + int64_t cropped_t = std::max(1, 1 + ((input.shape()[2] - 1) / 8) * 8); + if (cropped_t != input.shape()[2]) { + input = sd::ops::slice(input, 2, 0, cropped_t); + } + } + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(input, decode_graph); + }; + auto result = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), expected_dim); + if (result.empty()) { + return {}; + } + LOG_DEBUG("ltx_video_vae host output shape=[%lld,%lld,%lld,%lld] dim=%lld", + (long long)(result.shape().size() > 0 ? 
result.shape()[0] : 0), + (long long)(result.shape().size() > 1 ? result.shape()[1] : 0), + (long long)(result.shape().size() > 2 ? result.shape()[2] : 0), + (long long)(result.shape().size() > 3 ? result.shape()[3] : 0), + (long long)result.dim()); + return result; + } + + int get_encoder_output_channels(int input_channels) override { + SD_UNUSED(input_channels); + return 256; + } + + sd::Tensor vae_output_to_latents(const sd::Tensor& vae_output, std::shared_ptr rng) override { + SD_UNUSED(rng); + if (vae_output.dim() >= 4 && vae_output.shape()[3] > 128) { + return sd::ops::slice(vae_output, 3, 0, 128); + } + return vae_output; + } + + sd::Tensor diffusion_to_vae_latents(const sd::Tensor& latents) override { + return latents; + } + + sd::Tensor vae_to_diffusion_latents(const sd::Tensor& latents) override { + return latents; + } + + void test(const std::string& input_path) { + auto z = sd::load_tensor_from_file_as_tensor(input_path); + print_sd_tensor(z, false, "ltx_vae_z"); + + z = diffusion_to_vae_latents(z); + + int64_t t0 = ggml_time_ms(); + auto out = _compute(8, z, true); + int64_t t1 = ggml_time_ms(); + + GGML_ASSERT(!out.empty()); + print_sd_tensor(out, false, "ltx_vae_out"); + LOG_DEBUG("ltx vae test done in %lldms", t1 - t0); + } + + static void load_from_file_and_test(const std::string& model_path, + const std::string& input_path) { + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + LOG_INFO("loading ltx vae from '%s'", model_path.c_str()); + + ModelLoader model_loader; + if (!model_loader.init_from_file_and_convert_name(model_path, "vae.")) { + LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); + return; + } + + auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + std::shared_ptr vae = std::make_shared(backend, + false, + tensor_storage_map, + "first_stage_model", + true, + VERSION_LTXAV); + + vae->alloc_params_buffer(); + std::map tensors; + 
vae->get_param_tensors(tensors, "first_stage_model"); + + if (!model_loader.load_tensors(tensors)) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("ltx vae model loaded"); + vae->test(input_path); + } +}; + +#endif // __SD_LTX_VAE_HPP__ diff --git a/src/ltxv.hpp b/src/ltxv.hpp index fb37dbe0..c79a1726 100644 --- a/src/ltxv.hpp +++ b/src/ltxv.hpp @@ -1,73 +1,1886 @@ -#ifndef __LTXV_HPP__ -#define __LTXV_HPP__ +#ifndef __SD_LTXV_HPP__ +#define __SD_LTXV_HPP__ + +#include +#include +#include +#include +#include +#include +#include #include "common_block.hpp" +#include "flux.hpp" +#include "rope.hpp" namespace LTXV { - class CausalConv3d : public GGMLBlock { - protected: - int time_kernel_size; + constexpr int LTXAV_GRAPH_SIZE = 102400; - public: - CausalConv3d(int64_t in_channels, - int64_t out_channels, - int kernel_size = 3, - std::tuple stride = {1, 1, 1}, - int dilation = 1, - bool bias = true) { - time_kernel_size = kernel_size / 2; - blocks["conv"] = std::shared_ptr(new Conv3d(in_channels, - out_channels, - {kernel_size, kernel_size, kernel_size}, - stride, - {0, kernel_size / 2, kernel_size / 2}, - {dilation, 1, 1}, - bias)); + __STATIC_INLINE__ ggml_tensor* rms_norm(ggml_context* ctx, + ggml_tensor* x, + float eps = 1e-6f) { + return ggml_rms_norm(ctx, x, eps); + } + + __STATIC_INLINE__ ggml_tensor* apply_gate(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* gate) { + if (gate->ne[1] != 1) { + gate = ggml_reshape_3d(ctx, gate, gate->ne[0], 1, gate->ne[2]); + } + return ggml_mul(ctx, x, gate); + } + + __STATIC_INLINE__ int count_prefix_blocks(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + const std::string& marker) { + int max_block = -1; + for (const auto& [name, _] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + size_t pos = name.find(marker); + if (pos == std::string::npos) { + continue; + } + pos += marker.size(); + size_t end = name.find(".", pos); + if (end 
== std::string::npos) { + continue; + } + int block = atoi(name.substr(pos, end - pos).c_str()); + max_block = std::max(max_block, block); + } + return max_block + 1; + } + + __STATIC_INLINE__ std::vector generate_freq_grid(float theta, + int positional_dims, + int dim) { + const int n_elem = 2 * positional_dims; + const int freq_count = dim / n_elem; + + std::vector out(freq_count); + if (freq_count <= 0) { + return out; + } + if (freq_count == 1) { + out[0] = 1.5707963267948966f; + return out; + } + + const float half_pi = 1.5707963267948966f; + const float log_theta = std::log(theta); + for (int i = 0; i < freq_count; i++) { + float ratio = static_cast(i) / static_cast(freq_count - 1); + out[i] = std::exp(log_theta * ratio) * half_pi; + } + return out; + } + + __STATIC_INLINE__ std::vector generate_freq_grid_double(double theta, + int positional_dims, + int dim) { + const int n_elem = 2 * positional_dims; + const int freq_count = dim / n_elem; + + std::vector out(freq_count); + if (freq_count <= 0) { + return out; + } + if (freq_count == 1) { + out[0] = 1.5707963267948966; + return out; + } + + const double half_pi = 1.5707963267948966; + const double log_theta = std::log(theta); + for (int i = 0; i < freq_count; i++) { + double ratio = static_cast(i) / static_cast(freq_count - 1); + out[i] = std::exp(log_theta * ratio) * half_pi; + } + return out; + } + + __STATIC_INLINE__ std::vector build_rope_matrix_from_frequencies( + const std::vector>& frequencies, + int dim) { + const int half_dim = dim / 2; + std::vector out(static_cast(frequencies.size()) * static_cast(half_dim) * 4, 0.f); + + for (size_t token = 0; token < frequencies.size(); token++) { + for (int i = 0; i < half_dim; i++) { + float angle = i < static_cast(frequencies[token].size()) ? 
frequencies[token][i] : 0.f; + float c = std::cos(angle); + float s = std::sin(angle); + + size_t base = (token * static_cast(half_dim) + static_cast(i)) * 4; + out[base + 0] = c; + out[base + 1] = -s; + out[base + 2] = s; + out[base + 3] = c; + } + } + + return out; + } + + __STATIC_INLINE__ std::vector> split_frequencies_by_heads( + const std::vector>& frequencies, + int inner_dim, + int num_heads) { + GGML_ASSERT(num_heads > 0); + GGML_ASSERT(inner_dim % num_heads == 0); + const int inner_half_dim = inner_dim / 2; + const int per_head_half_dim = inner_half_dim / num_heads; + GGML_ASSERT(inner_half_dim % num_heads == 0); + + std::vector> out( + frequencies.size() * static_cast(num_heads), + std::vector(per_head_half_dim, 0.f)); + + for (size_t token = 0; token < frequencies.size(); token++) { + GGML_ASSERT(static_cast(frequencies[token].size()) == inner_half_dim); + for (int head = 0; head < num_heads; head++) { + auto& dst = out[token * static_cast(num_heads) + static_cast(head)]; + std::copy_n(frequencies[token].begin() + head * per_head_half_dim, per_head_half_dim, dst.begin()); + } + } + return out; + } + + __STATIC_INLINE__ std::vector build_video_rope_matrix(int64_t width, + int64_t height, + int64_t frames, + int dim, + int num_heads = 1, + float frame_rate = 25.f, + float theta = 10000.f, + const std::vector& max_pos = {20, 2048, 2048}, + const std::tuple& vae_scale_factors = {8, 32, 32}, + bool causal_temporal_positioning = false, + bool use_middle_indices_grid = false) { + GGML_ASSERT(max_pos.size() == 3); + GGML_ASSERT(dim % num_heads == 0); + const std::vector indices = generate_freq_grid(theta, 3, dim); + const int half_dim = dim / 2; + const int pad_size = half_dim - static_cast(indices.size()) * 3; + + std::vector> freqs(static_cast(width * height * frames), std::vector(half_dim, 0.f)); + + const int scale_t = std::get<0>(vae_scale_factors); + const int scale_h = std::get<1>(vae_scale_factors); + const int scale_w = std::get<2>(vae_scale_factors); 
+ + size_t token = 0; + for (int64_t t = 0; t < frames; t++) { + float pixel_t = static_cast(t * scale_t); + if (causal_temporal_positioning) { + pixel_t = std::max(0.f, pixel_t + 1.f - scale_t); + } + pixel_t /= frame_rate; + if (use_middle_indices_grid) { + float end = static_cast((t + 1) * scale_t); + if (causal_temporal_positioning) { + end = std::max(0.f, end + 1.f - scale_t); + } + end /= frame_rate; + pixel_t = 0.5f * (pixel_t + end); + } + + for (int64_t h = 0; h < height; h++) { + float pixel_h = static_cast(h * scale_h); + if (use_middle_indices_grid) { + pixel_h += 0.5f * static_cast(scale_h); + } + for (int64_t w = 0; w < width; w++) { + float pixel_w = static_cast(w * scale_w); + if (use_middle_indices_grid) { + pixel_w += 0.5f * static_cast(scale_w); + } + + int out_idx = 0; + for (int i = 0; i < pad_size; i++) { + freqs[token][out_idx++] = 0.f; + } + + const float coords[3] = { + pixel_t / max_pos[0], + pixel_h / max_pos[1], + pixel_w / max_pos[2], + }; + + for (float index : indices) { + for (int axis = 0; axis < 3; axis++) { + freqs[token][out_idx++] = index * (coords[axis] * 2.f - 1.f); + } + } + token++; + } + } + } + + if (num_heads > 1) { + return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); + } + return build_rope_matrix_from_frequencies(freqs, dim); + } + + __STATIC_INLINE__ std::vector build_1d_rope_matrix(int64_t seq_len, + int dim, + int num_heads = 1, + float theta = 10000.f, + float positional_scale = 4096.f, + bool double_precision = false) { + GGML_ASSERT(dim % num_heads == 0); + const std::vector indices = double_precision ? std::vector() : generate_freq_grid(theta, 1, dim); + const std::vector indices_d = + double_precision ? generate_freq_grid_double(static_cast(theta), 1, dim) : std::vector(); + const int half_dim = dim / 2; + const int pad_size = half_dim - static_cast(double_precision ? 
indices_d.size() : indices.size()); + + std::vector> freqs(static_cast(seq_len), std::vector(half_dim, 0.f)); + for (int64_t pos = 0; pos < seq_len; pos++) { + int out_idx = 0; + for (int i = 0; i < pad_size; i++) { + freqs[static_cast(pos)][out_idx++] = 0.f; + } + + if (double_precision) { + double coord = static_cast(pos) / static_cast(positional_scale); + for (double index : indices_d) { + freqs[static_cast(pos)][out_idx++] = static_cast(index * (coord * 2.0 - 1.0)); + } + } else { + float coord = static_cast(pos) / positional_scale; + for (float index : indices) { + freqs[static_cast(pos)][out_idx++] = index * (coord * 2.f - 1.f); + } + } + } + + if (num_heads > 1) { + return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); + } + return build_rope_matrix_from_frequencies(freqs, dim); + } + + __STATIC_INLINE__ ggml_tensor* apply_hidden_rope(ggml_context* ctx, + ggml_tensor* x, + ggml_tensor* pe, + int64_t heads, + int64_t dim_head, + bool rope_interleaved) { + GGML_ASSERT(x->ne[0] == heads * dim_head); + auto x4 = ggml_reshape_4d(ctx, x, dim_head, heads, x->ne[1], x->ne[2]); + if (pe != nullptr && pe->ne[3] == x->ne[1] * heads) { + auto x_flat = ggml_reshape_4d(ctx, x4, dim_head, 1, x->ne[1] * heads, x->ne[2]); + auto out_flat = Rope::apply_rope(ctx, x_flat, pe, rope_interleaved); + auto out4 = ggml_reshape_4d(ctx, out_flat, dim_head, heads, x->ne[1], x->ne[2]); + return ggml_reshape_3d(ctx, out4, heads * dim_head, x->ne[1], x->ne[2]); + } + return Rope::apply_rope(ctx, x4, pe, rope_interleaved); + } + + struct TimestepEmbedder : public GGMLBlock { + int frequency_embedding_size; + + TimestepEmbedder(int64_t hidden_size, + int frequency_embedding_size = 256) + : frequency_embedding_size(frequency_embedding_size) { + blocks["linear_1"] = std::make_shared(frequency_embedding_size, hidden_size, true, true); + blocks["linear_2"] = std::make_shared(hidden_size, hidden_size, true, true); + } + + ggml_tensor* 
forward(GGMLRunnerContext* ctx, ggml_tensor* timestep) { + auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); + auto linear_2 = std::dynamic_pointer_cast(blocks["linear_2"]); + + auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, frequency_embedding_size); + t_emb = linear_1->forward(ctx, t_emb); + t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb); + t_emb = linear_2->forward(ctx, t_emb); + return t_emb; + } + }; + + struct AdaLayerNormSingle : public GGMLBlock { + int64_t embedding_dim; + int64_t embedding_coefficient; + + AdaLayerNormSingle(int64_t embedding_dim, + int64_t embedding_coefficient = 6) + : embedding_dim(embedding_dim), embedding_coefficient(embedding_coefficient) { + blocks["emb.timestep_embedder"] = std::make_shared(embedding_dim); + blocks["linear"] = std::make_shared(embedding_dim, + embedding_coefficient * embedding_dim, + true, + true); + } + + std::pair forward(GGMLRunnerContext* ctx, + ggml_tensor* timestep) { + auto timestep_embedder = std::dynamic_pointer_cast(blocks["emb.timestep_embedder"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + + auto embedded_timestep = timestep_embedder->forward(ctx, timestep); + auto hidden = ggml_silu(ctx->ggml_ctx, embedded_timestep); + auto out = linear->forward(ctx, hidden); + return {out, embedded_timestep}; + } + }; + + struct PixArtAlphaTextProjection : public GGMLBlock { + PixArtAlphaTextProjection(int64_t in_features, + int64_t hidden_size, + int64_t out_features = -1) { + if (out_features < 0) { + out_features = hidden_size; + } + blocks["linear_1"] = std::make_shared(in_features, hidden_size, true, true); + blocks["linear_2"] = std::make_shared(hidden_size, out_features, true, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* caption) { + auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); + auto linear_2 = std::dynamic_pointer_cast(blocks["linear_2"]); + + caption = linear_1->forward(ctx, caption); + caption = 
ggml_ext_gelu(ctx->ggml_ctx, caption, true); + caption = linear_2->forward(ctx, caption); + return caption; + } + }; + + struct NormSingleLinearTextProjection : public GGMLBlock { + int64_t in_features; + int64_t hidden_size; + + NormSingleLinearTextProjection(int64_t in_features, + int64_t hidden_size) + : in_features(in_features), hidden_size(hidden_size) { + blocks["linear_1"] = std::make_shared(in_features, hidden_size, true, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* caption) { + auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); + caption = ggml_rms_norm(ctx->ggml_ctx, caption, 1e-6f); + caption = ggml_ext_scale(ctx->ggml_ctx, caption, std::sqrt(static_cast(hidden_size) / static_cast(in_features))); + return linear_1->forward(ctx, caption); + } + }; + + struct CrossAttention : public GGMLBlock { + int64_t heads; + int64_t dim_head; + bool rope_interleaved; + + CrossAttention(int64_t query_dim, + int64_t context_dim, + int64_t heads, + int64_t dim_head, + bool apply_gated_attention = false, + bool rope_interleaved = true) + : heads(heads), dim_head(dim_head), rope_interleaved(rope_interleaved) { + int64_t inner_dim = heads * dim_head; + blocks["q_norm"] = std::make_shared(inner_dim, 1e-5f); + blocks["k_norm"] = std::make_shared(inner_dim, 1e-5f); + blocks["to_q"] = std::make_shared(query_dim, inner_dim, true); + blocks["to_k"] = std::make_shared(context_dim, inner_dim, true); + blocks["to_v"] = std::make_shared(context_dim, inner_dim, true); + if (apply_gated_attention) { + blocks["to_gate_logits"] = std::make_shared(query_dim, heads, true); + } + blocks["to_out.0"] = std::make_shared(inner_dim, query_dim, true); } ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, - bool causal = true) { - // x: [N*IC, ID, IH, IW] - // result: [N*OC, OD, OH, OW] - auto conv = std::dynamic_pointer_cast(blocks["conv"]); - if (causal) { - auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - auto 
first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < time_kernel_size - 1; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } - x = ggml_concat(ctx, first_frame_pad, x, 2); - } else { - auto h = ggml_cont(ctx, ggml_permute(ctx, x, 0, 1, 3, 2)); // [ID, N*IC, IH, IW] - int64_t offset = h->nb[2] * h->ne[2]; - - auto first_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], 0); // [N*IC, IH, IW] - first_frame = ggml_reshape_4d(ctx, first_frame, first_frame->ne[0], first_frame->ne[1], 1, first_frame->ne[2]); // [N*IC, 1, IH, IW] - auto first_frame_pad = first_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - first_frame_pad = ggml_concat(ctx, first_frame_pad, first_frame, 2); - } - - auto last_frame = ggml_view_3d(ctx, h, h->ne[0], h->ne[1], h->ne[2], h->nb[1], h->nb[2], offset * (h->ne[3] - 1)); // [N*IC, IH, IW] - last_frame = ggml_reshape_4d(ctx, last_frame, last_frame->ne[0], last_frame->ne[1], 1, last_frame->ne[2]); // [N*IC, 1, IH, IW] - auto last_frame_pad = last_frame; - for (int i = 1; i < (time_kernel_size - 1) / 2; i++) { - last_frame_pad = ggml_concat(ctx, last_frame_pad, last_frame, 2); - } - - x = ggml_concat(ctx, first_frame_pad, x, 2); - x = ggml_concat(ctx, x, last_frame_pad, 2); + ggml_tensor* context = nullptr, + ggml_tensor* mask = nullptr, + ggml_tensor* pe = nullptr, + ggml_tensor* k_pe = nullptr) { + if (context == nullptr) { + context = x; } - x = conv->forward(ctx, x); + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto to_k = std::dynamic_pointer_cast(blocks["to_k"]); + auto to_v = std::dynamic_pointer_cast(blocks["to_v"]); + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = 
std::dynamic_pointer_cast(blocks["k_norm"]); + auto to_out_0 = std::dynamic_pointer_cast(blocks["to_out.0"]); + + auto q = to_q->forward(ctx, x); + auto k = to_k->forward(ctx, context); + auto v = to_v->forward(ctx, context); + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + + if (pe != nullptr) { + if (k_pe == nullptr) { + k_pe = pe; + } + q = apply_hidden_rope(ctx->ggml_ctx, q, pe, heads, dim_head, rope_interleaved); + k = apply_hidden_rope(ctx->ggml_ctx, k, k_pe, heads, dim_head, rope_interleaved); + } + + auto out = ggml_ext_attention_ext(ctx->ggml_ctx, + ctx->backend, + q, + k, + v, + heads, + mask, + false, + ctx->flash_attn_enabled); + + if (blocks.count("to_gate_logits") > 0) { + auto to_gate_logits = std::dynamic_pointer_cast(blocks["to_gate_logits"]); + auto gate_logits = to_gate_logits->forward(ctx, x); + auto gates = ggml_sigmoid(ctx->ggml_ctx, gate_logits); + gates = ggml_ext_scale(ctx->ggml_ctx, gates, 2.0f, true); + gates = ggml_reshape_4d(ctx->ggml_ctx, gates, 1, heads, gate_logits->ne[1], gate_logits->ne[2]); + + auto out4 = ggml_reshape_4d(ctx->ggml_ctx, out, dim_head, heads, out->ne[1], out->ne[2]); + gates = ggml_repeat(ctx->ggml_ctx, gates, out4); + out4 = ggml_mul(ctx->ggml_ctx, out4, gates); + out = ggml_reshape_3d(ctx->ggml_ctx, out4, heads * dim_head, out4->ne[2], out4->ne[3]); + } + + return to_out_0->forward(ctx, out); + } + }; + + struct BasicTransformerBlock : public GGMLBlock { + int64_t dim; + bool cross_attention_adaln; + bool self_attention_gated; + bool cross_attention_gated; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + ggml_type wtype = get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32); + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, wtype, dim, cross_attention_adaln ? 
9 : 6); + if (cross_attention_adaln) { + ggml_type prompt_wtype = get_type(prefix + "prompt_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); + params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, prompt_wtype, dim, 2); + } + } + + BasicTransformerBlock(int64_t dim, + int64_t n_heads, + int64_t d_head, + int64_t context_dim, + bool rope_interleaved = true, + bool cross_attention_adaln = false, + bool self_attention_gated = false, + bool cross_attention_gated = false) + : dim(dim), + cross_attention_adaln(cross_attention_adaln), + self_attention_gated(self_attention_gated), + cross_attention_gated(cross_attention_gated) { + blocks["attn1"] = std::make_shared(dim, dim, n_heads, d_head, self_attention_gated, rope_interleaved); + blocks["attn2"] = std::make_shared(dim, context_dim, n_heads, d_head, cross_attention_gated, false); + blocks["ff"] = std::make_shared(dim, dim, 4, FeedForward::Activation::GELU); + } + + std::vector get_scale_shift_values(GGMLRunnerContext* ctx, + ggml_tensor* timestep) { + auto table = params["scale_shift_table"]; + int64_t batch = timestep->ne[1]; + + int64_t coeff = cross_attention_adaln ? 
9 : 6; + auto t = ggml_reshape_3d(ctx->ggml_ctx, timestep, dim, coeff, batch); + auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, coeff, 1); + auto e = ggml_new_tensor_3d(ctx->ggml_ctx, timestep->type, dim, coeff, batch); + s = ggml_repeat(ctx->ggml_ctx, s, e); + t = ggml_repeat(ctx->ggml_ctx, t, e); + auto out = ggml_add(ctx->ggml_ctx, s, t); + GGML_ASSERT(coeff <= INT_MAX); + return ggml_ext_chunk(ctx->ggml_ctx, out, static_cast(coeff), 1); + } + + std::vector get_prompt_scale_shift_values(GGMLRunnerContext* ctx, + ggml_tensor* prompt_timestep) { + auto table = params["prompt_scale_shift_table"]; + int64_t batch = prompt_timestep->ne[1]; + + auto t = ggml_reshape_3d(ctx->ggml_ctx, prompt_timestep, dim, 2, batch); + auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, 2, 1); + auto e = ggml_new_tensor_3d(ctx->ggml_ctx, prompt_timestep->type, dim, 2, batch); + s = ggml_repeat(ctx->ggml_ctx, s, e); + t = ggml_repeat(ctx->ggml_ctx, t, e); + auto out = ggml_add(ctx->ggml_ctx, s, t); + return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + ggml_tensor* timestep, + ggml_tensor* prompt_timestep, + ggml_tensor* pe, + ggml_tensor* attention_mask = nullptr, + ggml_tensor* self_attention_mask = nullptr) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + auto mods = get_scale_shift_values(ctx, timestep); + auto shift_msa = mods[0]; + auto scale_msa = mods[1]; + auto gate_msa = mods[2]; + auto shift_mlp = mods[3]; + auto scale_mlp = mods[4]; + auto gate_mlp = mods[5]; + + auto x_norm = rms_norm(ctx->ggml_ctx, x); + x_norm = Flux::modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa, true); + auto msa = attn1->forward(ctx, x_norm, nullptr, self_attention_mask, pe); + x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, msa, gate_msa)); + + if 
(cross_attention_adaln) { + auto shift_q = mods[6]; + auto scale_q = mods[7]; + auto gate_q = mods[8]; + + auto q = rms_norm(ctx->ggml_ctx, x); + q = Flux::modulate(ctx->ggml_ctx, q, shift_q, scale_q, true); + + auto context_mod = context; + if (prompt_timestep != nullptr) { + auto prompt_mods = get_prompt_scale_shift_values(ctx, prompt_timestep); + context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, prompt_mods[0], prompt_mods[1], true); + } + + auto mca = attn2->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); + x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mca, gate_q)); + } else { + auto mca = attn2->forward(ctx, x, context, attention_mask, nullptr, nullptr); + x = ggml_add(ctx->ggml_ctx, x, mca); + } + + auto y = rms_norm(ctx->ggml_ctx, x); + y = Flux::modulate(ctx->ggml_ctx, y, shift_mlp, scale_mlp, true); + auto mlp_out = ff->forward(ctx, y); + x = ggml_add(ctx->ggml_ctx, x, apply_gate(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; } }; -}; + struct BasicTransformerBlock1D : public GGMLBlock { + BasicTransformerBlock1D(int64_t dim, + int64_t n_heads, + int64_t d_head, + bool rope_interleaved, + bool apply_gated_attention = false) { + blocks["attn1"] = std::make_shared(dim, dim, n_heads, d_head, apply_gated_attention, rope_interleaved); + blocks["ff"] = std::make_shared(dim, dim, 4, FeedForward::Activation::GELU); + } -#endif \ No newline at end of file + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pe, + ggml_tensor* attention_mask = nullptr) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + auto h = rms_norm(ctx->ggml_ctx, x); + h = attn1->forward(ctx, h, nullptr, attention_mask, pe); + x = ggml_add(ctx->ggml_ctx, x, h); + + h = rms_norm(ctx->ggml_ctx, x); + h = ff->forward(ctx, h); + x = ggml_add(ctx->ggml_ctx, x, h); + return x; + } + }; + + struct Embeddings1DConnector : public GGMLBlock { + int64_t hidden_size; + int64_t 
num_attention_heads; + int64_t attention_head_dim; + int64_t num_layers; + int64_t num_learnable_registers; + bool rope_interleaved; + bool apply_gated_attention; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + if (num_learnable_registers > 0) { + ggml_type wtype = get_type(prefix + "learnable_registers", tensor_storage_map, GGML_TYPE_F32); + params["learnable_registers"] = ggml_new_tensor_2d(ctx, wtype, hidden_size, num_learnable_registers); + } + } + + Embeddings1DConnector(int64_t hidden_size, + int64_t num_attention_heads = 30, + int64_t attention_head_dim = 128, + int64_t num_layers = 2, + int64_t num_learnable_registers = 128, + bool rope_interleaved = false, + bool apply_gated_attention = false) + : hidden_size(hidden_size), + num_attention_heads(num_attention_heads), + attention_head_dim(attention_head_dim), + num_layers(num_layers), + num_learnable_registers(num_learnable_registers), + rope_interleaved(rope_interleaved), + apply_gated_attention(apply_gated_attention) { + for (int i = 0; i < num_layers; i++) { + blocks["transformer_1d_blocks." 
+ std::to_string(i)] = + std::make_shared(hidden_size, + num_attention_heads, + attention_head_dim, + rope_interleaved, + apply_gated_attention); + } + } + + ggml_tensor* append_registers(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states) { + if (num_learnable_registers <= 0 || params.count("learnable_registers") == 0) { + return hidden_states; + } + + int64_t seq_len = hidden_states->ne[1]; + int64_t target_len = std::max(1024, seq_len); + int64_t duplications = (target_len + num_learnable_registers - 1) / num_learnable_registers; + int64_t total_to_keep = duplications * num_learnable_registers - seq_len; + if (total_to_keep <= 0) { + return hidden_states; + } + + auto regs = ggml_reshape_3d(ctx->ggml_ctx, params["learnable_registers"], hidden_size, num_learnable_registers, 1); + auto temp = ggml_new_tensor_3d(ctx->ggml_ctx, regs->type, regs->ne[0], regs->ne[1], hidden_states->ne[2]); + regs = ggml_repeat(ctx->ggml_ctx, regs, temp); + + auto regs_full = regs; + for (int64_t i = 1; i < duplications; i++) { + regs_full = ggml_concat(ctx->ggml_ctx, regs_full, regs, 1); + } + regs_full = ggml_ext_slice(ctx->ggml_ctx, regs_full, 1, seq_len, seq_len + total_to_keep); + return ggml_concat(ctx->ggml_ctx, hidden_states, regs_full, 1); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* hidden_states, + ggml_tensor* pe, + ggml_tensor* attention_mask = nullptr) { + hidden_states = append_registers(ctx, hidden_states); + + for (int i = 0; i < num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["transformer_1d_blocks." 
+ std::to_string(i)]); + hidden_states = block->forward(ctx, hidden_states, pe, attention_mask); + } + + return ggml_rms_norm(ctx->ggml_ctx, hidden_states, 1e-6f); + } + }; + + struct LTXAVParams { + int64_t in_channels = 128; + int64_t out_channels = 128; + int64_t hidden_size = 3840; + int64_t cross_attention_dim = 4096; + int64_t caption_channels = 3840; + int64_t num_attention_heads = 30; + int64_t attention_head_dim = 128; + int64_t num_layers = 28; + float positional_embedding_theta = 10000.f; + std::vector positional_embedding_max_pos = {20, 2048, 2048}; + std::tuple vae_scale_factors = {8, 32, 32}; + bool causal_temporal_positioning = true; + float timestep_scale_multiplier = 1000.f; + + int64_t audio_in_channels = 128; + int64_t audio_out_channels = 128; + int64_t audio_hidden_size = 2048; + int64_t audio_cross_attention_dim = 2048; + int64_t audio_num_attention_heads = 32; + int64_t audio_attention_head_dim = 64; + std::vector audio_positional_embedding_max_pos = {20}; + float av_ca_timestep_scale_multiplier = 1.f; + int64_t num_audio_channels = 8; + int64_t audio_frequency_bins = 16; + + bool use_connector = false; + int64_t connector_hidden_size = 3840; + int64_t connector_num_heads = 30; + int64_t connector_head_dim = 128; + int64_t connector_num_layers = 2; + int64_t connector_num_registers = 128; + bool connector_rope_interleaved = false; + bool connector_apply_gated_attention = false; + + bool use_audio_connector = false; + int64_t audio_connector_hidden_size = 2048; + int64_t audio_connector_num_heads = 32; + int64_t audio_connector_head_dim = 64; + int64_t audio_connector_num_layers = 2; + int64_t audio_connector_num_registers = 128; + bool audio_connector_rope_interleaved = false; + bool audio_connector_apply_gated_attention = false; + + bool video_rope_interleaved = false; + bool use_middle_indices_grid = true; + bool cross_attention_adaln = false; + + bool use_caption_projection = true; + bool use_audio_caption_projection = true; + bool 
caption_proj_before_connector = true; + bool caption_projection_first_linear = false; + + bool self_attention_gated = false; + bool cross_attention_gated = false; + }; + + __STATIC_INLINE__ std::pair infer_attention_layout(int64_t hidden_size, + int64_t preferred_heads = -1) { + if (preferred_heads > 0 && hidden_size % preferred_heads == 0) { + return {preferred_heads, hidden_size / preferred_heads}; + } + const int candidates[] = {128, 96, 80, 64, 48, 40, 32}; + for (int head_dim : candidates) { + if (hidden_size % head_dim == 0) { + int64_t heads = hidden_size / head_dim; + if (heads >= 8 && heads <= 64) { + return {heads, head_dim}; + } + } + } + return {32, hidden_size / 32}; + } + + __STATIC_INLINE__ std::vector build_1d_rope_matrix_from_coords(const std::vector& coords, + int dim, + int num_heads = 1, + float theta = 10000.f, + float max_pos = 20.f, + bool double_precision = false) { + GGML_ASSERT(dim % num_heads == 0); + const std::vector indices = double_precision ? std::vector() : generate_freq_grid(theta, 1, dim); + const std::vector indices_d = + double_precision ? generate_freq_grid_double(static_cast(theta), 1, dim) : std::vector(); + const int half_dim = dim / 2; + const int pad_size = half_dim - static_cast(double_precision ? 
indices_d.size() : indices.size()); + + std::vector> freqs(coords.size(), std::vector(half_dim, 0.f)); + for (size_t pos = 0; pos < coords.size(); pos++) { + int out_idx = 0; + for (int i = 0; i < pad_size; i++) { + freqs[pos][out_idx++] = 0.f; + } + if (double_precision) { + double coord = static_cast(coords[pos]) / static_cast(max_pos); + for (double index : indices_d) { + freqs[pos][out_idx++] = static_cast(index * (coord * 2.0 - 1.0)); + } + } else { + float coord = coords[pos] / max_pos; + for (float index : indices) { + freqs[pos][out_idx++] = index * (coord * 2.f - 1.f); + } + } + } + if (num_heads > 1) { + return build_rope_matrix_from_frequencies(split_frequencies_by_heads(freqs, dim, num_heads), dim / num_heads); + } + return build_rope_matrix_from_frequencies(freqs, dim); + } + + __STATIC_INLINE__ float video_latent_corner_to_time_sec(int64_t corner_index, + int scale_t, + float frame_rate, + bool causal_temporal_positioning) { + float pixel_t = static_cast(corner_index * scale_t); + if (causal_temporal_positioning) { + pixel_t = std::max(0.f, pixel_t + 1.f - scale_t); + } + return pixel_t / frame_rate; + } + + __STATIC_INLINE__ std::vector build_video_temporal_rope_matrix(int64_t width, + int64_t height, + int64_t frames, + int dim, + int num_heads, + float frame_rate, + float theta, + int max_pos_t, + int scale_t, + bool causal_temporal_positioning, + bool use_middle_indices_grid) { + std::vector coords; + coords.reserve(static_cast(width * height * frames)); + for (int64_t t = 0; t < frames; t++) { + float coord = video_latent_corner_to_time_sec(t, scale_t, frame_rate, causal_temporal_positioning); + if (use_middle_indices_grid) { + float end = video_latent_corner_to_time_sec(t + 1, scale_t, frame_rate, causal_temporal_positioning); + coord = 0.5f * (coord + end); + } + for (int64_t h = 0; h < height; h++) { + for (int64_t w = 0; w < width; w++) { + coords.push_back(coord); + } + } + } + return build_1d_rope_matrix_from_coords(coords, dim, num_heads, 
theta, static_cast(max_pos_t)); + } + + __STATIC_INLINE__ float audio_latent_start_time_sec(int64_t latent_index, + int audio_latent_downsample_factor = 4, + int hop_length = 160, + int sample_rate = 16000, + bool causal = true) { + float mel_frame = static_cast(latent_index * audio_latent_downsample_factor); + if (causal) { + mel_frame = std::max(0.f, mel_frame + 1.f - static_cast(audio_latent_downsample_factor)); + } + return mel_frame * static_cast(hop_length) / static_cast(sample_rate); + } + + __STATIC_INLINE__ std::vector build_audio_rope_matrix(int64_t seq_len, + int dim, + int num_heads, + float theta = 10000.f, + int max_pos_t = 20, + bool use_middle_indices_grid = false) { + std::vector coords(static_cast(seq_len), 0.f); + for (int64_t t = 0; t < seq_len; t++) { + float start = audio_latent_start_time_sec(t); + if (use_middle_indices_grid) { + float end = audio_latent_start_time_sec(t + 1); + coords[static_cast(t)] = 0.5f * (start + end); + } else { + coords[static_cast(t)] = start; + } + } + return build_1d_rope_matrix_from_coords(coords, dim, num_heads, theta, static_cast(max_pos_t)); + } + + struct BasicAVTransformerBlock : public GGMLBlock { + int64_t v_dim; + int64_t a_dim; + bool cross_attention_adaln; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + int64_t coeff = cross_attention_adaln ? 
9 : 6; + ggml_type vw = get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32); + ggml_type aw = get_type(prefix + "audio_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, vw, v_dim, coeff); + params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, aw, a_dim, coeff); + + if (cross_attention_adaln) { + ggml_type vpw = get_type(prefix + "prompt_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); + ggml_type apw = get_type(prefix + "audio_prompt_scale_shift_table", tensor_storage_map, GGML_TYPE_F32); + params["prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, vpw, v_dim, 2); + params["audio_prompt_scale_shift_table"] = ggml_new_tensor_2d(ctx, apw, a_dim, 2); + } + + ggml_type avw = get_type(prefix + "scale_shift_table_a2v_ca_audio", tensor_storage_map, GGML_TYPE_F32); + ggml_type vaw = get_type(prefix + "scale_shift_table_a2v_ca_video", tensor_storage_map, GGML_TYPE_F32); + params["scale_shift_table_a2v_ca_audio"] = ggml_new_tensor_2d(ctx, avw, a_dim, 5); + params["scale_shift_table_a2v_ca_video"] = ggml_new_tensor_2d(ctx, vaw, v_dim, 5); + } + + BasicAVTransformerBlock(int64_t v_dim, + int64_t a_dim, + int64_t v_heads, + int64_t a_heads, + int64_t vd_head, + int64_t ad_head, + int64_t v_context_dim, + int64_t a_context_dim, + bool apply_gated_attention, + bool cross_attention_adaln, + bool video_rope_interleaved) + : v_dim(v_dim), + a_dim(a_dim), + cross_attention_adaln(cross_attention_adaln) { + blocks["attn1"] = std::make_shared(v_dim, v_dim, v_heads, vd_head, apply_gated_attention, video_rope_interleaved); + blocks["audio_attn1"] = std::make_shared(a_dim, a_dim, a_heads, ad_head, apply_gated_attention, false); + blocks["attn2"] = std::make_shared(v_dim, v_context_dim, v_heads, vd_head, apply_gated_attention, false); + blocks["audio_attn2"] = std::make_shared(a_dim, a_context_dim, a_heads, ad_head, apply_gated_attention, false); + blocks["audio_to_video_attn"] = 
std::make_shared(v_dim, a_dim, a_heads, ad_head, apply_gated_attention, false); + blocks["video_to_audio_attn"] = std::make_shared(a_dim, v_dim, a_heads, ad_head, apply_gated_attention, false); + blocks["ff"] = std::make_shared(v_dim, v_dim, 4, FeedForward::Activation::GELU); + blocks["audio_ff"] = std::make_shared(a_dim, a_dim, 4, FeedForward::Activation::GELU); + } + + std::vector get_ada_values(GGMLRunnerContext* ctx, + ggml_tensor* table, + ggml_tensor* timestep, + int64_t dim, + int64_t coeff, + int64_t start = 0, + int64_t count = -1) { + if (count < 0) { + count = coeff - start; + } + auto t = ggml_reshape_3d(ctx->ggml_ctx, timestep, dim, coeff, timestep->ne[1]); + auto s = ggml_reshape_3d(ctx->ggml_ctx, table, dim, coeff, 1); + auto e = ggml_new_tensor_3d(ctx->ggml_ctx, timestep->type, dim, coeff, timestep->ne[1]); + t = ggml_repeat(ctx->ggml_ctx, t, e); + s = ggml_repeat(ctx->ggml_ctx, s, e); + auto out = ggml_add(ctx->ggml_ctx, s, t); + GGML_ASSERT(coeff <= INT_MAX); + auto chunks = ggml_ext_chunk(ctx->ggml_ctx, out, static_cast(coeff), 1); + return std::vector(chunks.begin() + start, chunks.begin() + start + count); + } + + ggml_tensor* apply_text_cross_attention(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* context, + CrossAttention* attn, + ggml_tensor* table, + ggml_tensor* prompt_table, + ggml_tensor* timestep, + ggml_tensor* prompt_timestep, + int64_t dim, + ggml_tensor* attention_mask) { + if (cross_attention_adaln) { + auto q_mods = get_ada_values(ctx, table, timestep, dim, 9, 6, 3); + auto q = rms_norm(ctx->ggml_ctx, x); + q = Flux::modulate(ctx->ggml_ctx, q, q_mods[0], q_mods[1], true); + auto context_mod = context; + if (prompt_timestep != nullptr && prompt_table != nullptr) { + auto p_mods = get_ada_values(ctx, prompt_table, prompt_timestep, dim, 2); + context_mod = Flux::modulate(ctx->ggml_ctx, context_mod, p_mods[0], p_mods[1], true); + } + auto out = attn->forward(ctx, q, context_mod, attention_mask, nullptr, nullptr); + return 
apply_gate(ctx->ggml_ctx, out, q_mods[2]); + } + + auto q = rms_norm(ctx->ggml_ctx, x); + return attn->forward(ctx, q, context, attention_mask, nullptr, nullptr); + } + + std::pair forward(GGMLRunnerContext* ctx, + ggml_tensor* vx, + ggml_tensor* ax, + ggml_tensor* v_context, + ggml_tensor* a_context, + ggml_tensor* attention_mask, + ggml_tensor* v_timestep, + ggml_tensor* a_timestep, + ggml_tensor* v_pe, + ggml_tensor* a_pe, + ggml_tensor* v_cross_pe, + ggml_tensor* a_cross_pe, + ggml_tensor* v_cross_scale_shift_timestep, + ggml_tensor* a_cross_scale_shift_timestep, + ggml_tensor* v_cross_gate_timestep, + ggml_tensor* a_cross_gate_timestep, + ggml_tensor* v_prompt_timestep, + ggml_tensor* a_prompt_timestep, + ggml_tensor* self_attention_mask = nullptr) { + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto audio_attn1 = std::dynamic_pointer_cast(blocks["audio_attn1"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto audio_attn2 = std::dynamic_pointer_cast(blocks["audio_attn2"]); + auto audio_to_video_attn = std::dynamic_pointer_cast(blocks["audio_to_video_attn"]); + auto video_to_audio_attn = std::dynamic_pointer_cast(blocks["video_to_audio_attn"]); + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + auto audio_ff = std::dynamic_pointer_cast(blocks["audio_ff"]); + + auto v_table = params["scale_shift_table"]; + auto a_table = params["audio_scale_shift_table"]; + + bool run_ax = ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0; + bool run_a2v = run_ax; + bool run_v2a = run_ax; + + auto v_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 
9 : 6); + auto v_norm = rms_norm(ctx->ggml_ctx, vx); + v_norm = Flux::modulate(ctx->ggml_ctx, v_norm, v_mods[0], v_mods[1], true); + auto v_sa = attn1->forward(ctx, v_norm, nullptr, self_attention_mask, v_pe); + vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_sa, v_mods[2])); + auto v_txt = apply_text_cross_attention(ctx, + vx, + v_context, + attn2.get(), + v_table, + cross_attention_adaln ? params["prompt_scale_shift_table"] : nullptr, + v_timestep, + v_prompt_timestep, + v_dim, + attention_mask); + vx = ggml_add(ctx->ggml_ctx, vx, v_txt); + + if (run_ax) { + auto a_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6); + auto a_norm = rms_norm(ctx->ggml_ctx, ax); + a_norm = Flux::modulate(ctx->ggml_ctx, a_norm, a_mods[0], a_mods[1], true); + auto a_sa = audio_attn1->forward(ctx, a_norm, nullptr, nullptr, a_pe); + ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_sa, a_mods[2])); + auto a_txt = apply_text_cross_attention(ctx, + ax, + a_context, + audio_attn2.get(), + a_table, + cross_attention_adaln ? 
params["audio_prompt_scale_shift_table"] : nullptr, + a_timestep, + a_prompt_timestep, + a_dim, + attention_mask); + ax = ggml_add(ctx->ggml_ctx, ax, a_txt); + + auto vx_norm3 = rms_norm(ctx->ggml_ctx, vx); + auto ax_norm3 = rms_norm(ctx->ggml_ctx, ax); + + if (run_a2v) { + auto a2v_audio_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 0, 4); + auto a2v_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); + auto a2v_audio = get_ada_values(ctx, a2v_audio_table, a_cross_scale_shift_timestep, a_dim, 4); + auto a2v_video = get_ada_values(ctx, a2v_video_table, v_cross_scale_shift_timestep, v_dim, 4); + auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, a2v_video[1], a2v_video[0], true); + auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, a2v_audio[1], a2v_audio[0], true); + auto a2v_out = audio_to_video_attn->forward(ctx, vx_scaled, ax_scaled, nullptr, v_cross_pe, a_cross_pe); + auto a2v_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 4, 5); + auto a2v_gate = get_ada_values(ctx, a2v_gate_table, v_cross_gate_timestep, v_dim, 1)[0]; + vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, a2v_out, a2v_gate)); + } + + if (run_v2a) { + auto v2a_audio_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 0, 4); + auto v2a_video_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_video"], 1, 0, 4); + auto v2a_audio = get_ada_values(ctx, v2a_audio_table, a_cross_scale_shift_timestep, a_dim, 4); + auto v2a_video = get_ada_values(ctx, v2a_video_table, v_cross_scale_shift_timestep, v_dim, 4); + auto ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_norm3, v2a_audio[3], v2a_audio[2], true); + auto vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_norm3, v2a_video[3], v2a_video[2], true); + auto v2a_out = video_to_audio_attn->forward(ctx, ax_scaled, vx_scaled, nullptr, a_cross_pe, v_cross_pe); + auto 
v2a_gate_table = ggml_ext_slice(ctx->ggml_ctx, params["scale_shift_table_a2v_ca_audio"], 1, 4, 5); + auto v2a_gate = get_ada_values(ctx, v2a_gate_table, a_cross_gate_timestep, a_dim, 1)[0]; + ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, v2a_out, v2a_gate)); + } + + auto a_ff_mods = get_ada_values(ctx, a_table, a_timestep, a_dim, cross_attention_adaln ? 9 : 6, 3, 3); + auto ax_scaled = rms_norm(ctx->ggml_ctx, ax); + ax_scaled = Flux::modulate(ctx->ggml_ctx, ax_scaled, a_ff_mods[0], a_ff_mods[1], true); + auto a_ff_out = audio_ff->forward(ctx, ax_scaled); + ax = ggml_add(ctx->ggml_ctx, ax, apply_gate(ctx->ggml_ctx, a_ff_out, a_ff_mods[2])); + } + + auto v_ff_mods = get_ada_values(ctx, v_table, v_timestep, v_dim, cross_attention_adaln ? 9 : 6, 3, 3); + auto vx_scaled = rms_norm(ctx->ggml_ctx, vx); + vx_scaled = Flux::modulate(ctx->ggml_ctx, vx_scaled, v_ff_mods[0], v_ff_mods[1], true); + auto v_ff_out = ff->forward(ctx, vx_scaled); + vx = ggml_add(ctx->ggml_ctx, vx, apply_gate(ctx->ggml_ctx, v_ff_out, v_ff_mods[2])); + + return {vx, ax}; + } + }; + + struct LTXAVModelBlock : public GGMLBlock { + LTXAVParams cfg; + + void init_params(ggml_context* ctx, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "") override { + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, + get_type(prefix + "scale_shift_table", tensor_storage_map, GGML_TYPE_F32), + cfg.hidden_size, + 2); + params["audio_scale_shift_table"] = ggml_new_tensor_2d(ctx, + get_type(prefix + "audio_scale_shift_table", tensor_storage_map, GGML_TYPE_F32), + cfg.audio_hidden_size, + 2); + } + + LTXAVModelBlock(const LTXAVParams& params) + : cfg(params) { + blocks["patchify_proj"] = std::make_shared(cfg.in_channels, cfg.hidden_size, true, true); + blocks["audio_patchify_proj"] = std::make_shared(cfg.audio_in_channels, cfg.audio_hidden_size, true, true); + blocks["adaln_single"] = std::make_shared(cfg.hidden_size, cfg.cross_attention_adaln ? 
9 : 6); + blocks["audio_adaln_single"] = std::make_shared(cfg.audio_hidden_size, cfg.cross_attention_adaln ? 9 : 6); + if (cfg.cross_attention_adaln) { + blocks["prompt_adaln_single"] = std::make_shared(cfg.hidden_size, 2); + blocks["audio_prompt_adaln_single"] = std::make_shared(cfg.audio_hidden_size, 2); + } + blocks["av_ca_video_scale_shift_adaln_single"] = std::make_shared(cfg.hidden_size, 4); + blocks["av_ca_a2v_gate_adaln_single"] = std::make_shared(cfg.hidden_size, 1); + blocks["av_ca_audio_scale_shift_adaln_single"] = std::make_shared(cfg.audio_hidden_size, 4); + blocks["av_ca_v2a_gate_adaln_single"] = std::make_shared(cfg.audio_hidden_size, 1); + + if (cfg.use_caption_projection) { + if (cfg.caption_proj_before_connector) { + if (cfg.caption_projection_first_linear) { + blocks["caption_projection"] = std::make_shared(cfg.caption_channels, cfg.hidden_size); + } + } else { + blocks["caption_projection"] = std::make_shared(cfg.caption_channels, cfg.hidden_size, cfg.hidden_size); + } + } + if (cfg.use_audio_caption_projection) { + if (cfg.caption_proj_before_connector) { + if (cfg.caption_projection_first_linear) { + blocks["audio_caption_projection"] = std::make_shared(cfg.caption_channels, cfg.audio_hidden_size); + } + } else { + blocks["audio_caption_projection"] = std::make_shared(cfg.caption_channels, cfg.audio_hidden_size, cfg.audio_hidden_size); + } + } + + if (cfg.use_connector) { + blocks["video_embeddings_connector"] = std::make_shared(cfg.connector_hidden_size, + cfg.connector_num_heads, + cfg.connector_head_dim, + cfg.connector_num_layers, + cfg.connector_num_registers, + cfg.connector_rope_interleaved, + cfg.connector_apply_gated_attention); + } + if (cfg.use_audio_connector) { + blocks["audio_embeddings_connector"] = std::make_shared(cfg.audio_connector_hidden_size, + cfg.audio_connector_num_heads, + cfg.audio_connector_head_dim, + cfg.audio_connector_num_layers, + cfg.audio_connector_num_registers, + cfg.audio_connector_rope_interleaved, + 
cfg.audio_connector_apply_gated_attention); + } + + for (int i = 0; i < cfg.num_layers; i++) { + blocks["transformer_blocks." + std::to_string(i)] = std::make_shared(cfg.hidden_size, + cfg.audio_hidden_size, + cfg.num_attention_heads, + cfg.audio_num_attention_heads, + cfg.attention_head_dim, + cfg.audio_attention_head_dim, + cfg.cross_attention_dim, + cfg.audio_cross_attention_dim, + cfg.self_attention_gated || cfg.cross_attention_gated, + cfg.cross_attention_adaln, + cfg.video_rope_interleaved); + } + + blocks["norm_out"] = std::make_shared(cfg.hidden_size, 1e-6f, false); + blocks["proj_out"] = std::make_shared(cfg.hidden_size, cfg.out_channels, true, true); + blocks["audio_norm_out"] = std::make_shared(cfg.audio_hidden_size, 1e-6f, false); + blocks["audio_proj_out"] = std::make_shared(cfg.audio_hidden_size, cfg.audio_out_channels, true, true); + } + + ggml_tensor* patchify_video(GGMLRunnerContext* ctx, ggml_tensor* x, int64_t n) { + x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / n, n); + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); + return x; + } + + ggml_tensor* unpatchify_video(GGMLRunnerContext* ctx, + ggml_tensor* x, + int64_t width, + int64_t height, + int64_t frames) { + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); + x = ggml_reshape_4d(ctx->ggml_ctx, x, width, height, frames, x->ne[1] * x->ne[2]); + return x; + } + + ggml_tensor* patchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax) { + ax = ggml_reshape_3d(ctx->ggml_ctx, ax, ax->ne[0] * ax->ne[2], ax->ne[1], ax->ne[3]); + return ax; + } + + ggml_tensor* unpatchify_audio(GGMLRunnerContext* ctx, ggml_tensor* ax, int64_t audio_length) { + if (ax == nullptr) { + return nullptr; + } + return ggml_reshape_4d(ctx->ggml_ctx, ax, cfg.audio_frequency_bins, audio_length, cfg.num_audio_channels, ax->ne[2]); + } + + std::pair preprocess_contexts(GGMLRunnerContext* ctx, + ggml_tensor* context, + 
ggml_tensor* video_connector_pe, + ggml_tensor* audio_connector_pe, + bool process_audio_context) { + if (context == nullptr) { + return {nullptr, nullptr}; + } + + bool is_fully_processed_context = + context->ne[0] == cfg.cross_attention_dim + cfg.audio_cross_attention_dim && + context->ne[1] >= 1024; + bool is_unprocessed_dual_context = + context->ne[0] == cfg.cross_attention_dim + cfg.audio_cross_attention_dim && + context->ne[1] < 1024; + + if (is_fully_processed_context) { + auto v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.cross_attention_dim); + ggml_tensor* a_context = nullptr; + if (process_audio_context) { + a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.cross_attention_dim, cfg.cross_attention_dim + cfg.audio_cross_attention_dim); + } + return {v_context, a_context}; + } + + ggml_tensor* v_context = context; + ggml_tensor* a_context = process_audio_context ? context : nullptr; + if (is_unprocessed_dual_context) { + v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.cross_attention_dim); + if (process_audio_context) { + a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.cross_attention_dim, cfg.cross_attention_dim + cfg.audio_cross_attention_dim); + } + } else if (context->ne[0] == cfg.caption_channels * 2) { + v_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, 0, cfg.caption_channels); + if (process_audio_context) { + a_context = ggml_ext_slice(ctx->ggml_ctx, context, 0, cfg.caption_channels, cfg.caption_channels * 2); + } + } + + if (cfg.caption_proj_before_connector) { + if (cfg.use_caption_projection && + blocks.count("caption_projection") > 0 && + v_context != nullptr && + v_context->ne[0] == cfg.caption_channels) { + auto caption_projection = std::dynamic_pointer_cast(blocks["caption_projection"]); + if (caption_projection != nullptr) { + v_context = caption_projection->forward(ctx, v_context); + } + } + if (process_audio_context && + cfg.use_audio_caption_projection && + 
blocks.count("audio_caption_projection") > 0 && + a_context != nullptr && + a_context->ne[0] == cfg.caption_channels) { + auto caption_projection = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); + if (caption_projection != nullptr) { + a_context = caption_projection->forward(ctx, a_context); + } + } + } + + if (cfg.use_connector && v_context != nullptr && v_context->ne[0] == cfg.connector_hidden_size) { + auto connector = std::dynamic_pointer_cast(blocks["video_embeddings_connector"]); + v_context = connector->forward(ctx, v_context, video_connector_pe); + } + if (process_audio_context && + cfg.use_audio_connector && + a_context != nullptr && + a_context->ne[0] == cfg.audio_connector_hidden_size) { + auto connector = std::dynamic_pointer_cast(blocks["audio_embeddings_connector"]); + a_context = connector->forward(ctx, a_context, audio_connector_pe); + } + + if (!cfg.caption_proj_before_connector && + cfg.use_caption_projection && + blocks.count("caption_projection") > 0 && + v_context != nullptr && + v_context->ne[0] == cfg.caption_channels) { + auto caption_projection = std::dynamic_pointer_cast(blocks["caption_projection"]); + if (caption_projection != nullptr) { + v_context = caption_projection->forward(ctx, v_context); + } + } + if (process_audio_context && + !cfg.caption_proj_before_connector && + cfg.use_audio_caption_projection && + blocks.count("audio_caption_projection") > 0 && + a_context != nullptr && + a_context->ne[0] == cfg.caption_channels) { + auto caption_projection = std::dynamic_pointer_cast(blocks["audio_caption_projection"]); + if (caption_projection != nullptr) { + a_context = caption_projection->forward(ctx, a_context); + } + } + + return {v_context, a_context}; + } + + std::vector get_output_scale_shift(GGMLRunnerContext* ctx, + ggml_tensor* table, + ggml_tensor* embedded_timestep, + int64_t dim) { + auto temp = ggml_new_tensor_3d(ctx->ggml_ctx, embedded_timestep->type, dim, 2, embedded_timestep->ne[1]); + auto t = 
ggml_repeat(ctx->ggml_ctx, ggml_reshape_3d(ctx->ggml_ctx, embedded_timestep, dim, 1, embedded_timestep->ne[1]), temp); + auto s = ggml_repeat(ctx->ggml_ctx, ggml_reshape_3d(ctx->ggml_ctx, table, dim, 2, 1), temp); + auto out = ggml_add(ctx->ggml_ctx, s, t); + return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1); + } + + std::pair forward(GGMLRunnerContext* ctx, + ggml_tensor* vx, + ggml_tensor* ax, + ggml_tensor* timestep, + ggml_tensor* audio_timestep, + ggml_tensor* context, + ggml_tensor* v_pe, + ggml_tensor* a_pe, + ggml_tensor* v_cross_pe, + ggml_tensor* a_cross_pe, + ggml_tensor* video_connector_pe, + ggml_tensor* audio_connector_pe) { + auto patchify_proj = std::dynamic_pointer_cast(blocks["patchify_proj"]); + auto audio_patchify_proj = std::dynamic_pointer_cast(blocks["audio_patchify_proj"]); + auto adaln_single = std::dynamic_pointer_cast(blocks["adaln_single"]); + auto audio_adaln_single = std::dynamic_pointer_cast(blocks["audio_adaln_single"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + auto audio_norm_out = std::dynamic_pointer_cast(blocks["audio_norm_out"]); + auto audio_proj_out = std::dynamic_pointer_cast(blocks["audio_proj_out"]); + + GGML_ASSERT(vx->ne[3] % cfg.in_channels == 0); + int64_t n = vx->ne[3] / cfg.in_channels; + int64_t width = vx->ne[0]; + int64_t height = vx->ne[1]; + int64_t frames = vx->ne[2]; + int64_t audio_time = ax != nullptr ? ax->ne[1] : 0; + + vx = patchify_video(ctx, vx, n); + vx = patchify_proj->forward(ctx, vx); + if (ax != nullptr && ggml_nelements(ax) > 0 && audio_time > 0) { + ax = patchify_audio(ctx, ax); + ax = audio_patchify_proj->forward(ctx, ax); + } else { + ax = nullptr; + } + + bool run_ax = ax != nullptr && ggml_nelements(ax) > 0 && audio_time > 0; + auto contexts = preprocess_contexts(ctx, context, video_connector_pe, audio_connector_pe, run_ax); + auto v_context = contexts.first; + auto a_context = contexts.second != nullptr ? 
contexts.second : contexts.first; + if (contexts.second != nullptr) { + a_context = ggml_cont(ctx->ggml_ctx, a_context); + } + + auto v_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, timestep, cfg.timestep_scale_multiplier); + auto v_pair = adaln_single->forward(ctx, v_timestep_scaled); + auto v_timestep_mod = v_pair.first; + auto v_embedded_time = v_pair.second; + + ggml_tensor* effective_audio_timestep = audio_timestep != nullptr ? audio_timestep : timestep; + auto a_timestep_scaled = ggml_ext_scale(ctx->ggml_ctx, effective_audio_timestep, cfg.timestep_scale_multiplier); + auto a_pair = audio_adaln_single->forward(ctx, a_timestep_scaled); + auto a_timestep_mod = a_pair.first; + auto a_embedded_time = a_pair.second; + + ggml_tensor* v_prompt_timestep_mod = nullptr; + ggml_tensor* a_prompt_timestep_mod = nullptr; + if (cfg.cross_attention_adaln) { + auto prompt_adaln_single = std::dynamic_pointer_cast(blocks["prompt_adaln_single"]); + auto audio_prompt_adaln_single = std::dynamic_pointer_cast(blocks["audio_prompt_adaln_single"]); + v_prompt_timestep_mod = prompt_adaln_single->forward(ctx, v_timestep_scaled).first; + a_prompt_timestep_mod = audio_prompt_adaln_single->forward(ctx, a_timestep_scaled).first; + } + + auto av_ca_video_scale_shift_timestep = + std::dynamic_pointer_cast(blocks["av_ca_video_scale_shift_adaln_single"])->forward(ctx, a_timestep_scaled).first; + auto av_ca_a2v_gate_noise_timestep = + std::dynamic_pointer_cast(blocks["av_ca_a2v_gate_adaln_single"]) + ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, a_timestep_scaled, cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier)) + .first; + auto av_ca_audio_scale_shift_timestep = + std::dynamic_pointer_cast(blocks["av_ca_audio_scale_shift_adaln_single"])->forward(ctx, v_timestep_scaled).first; + auto av_ca_v2a_gate_noise_timestep = + std::dynamic_pointer_cast(blocks["av_ca_v2a_gate_adaln_single"]) + ->forward(ctx, ggml_ext_scale(ctx->ggml_ctx, v_timestep_scaled, 
cfg.av_ca_timestep_scale_multiplier / cfg.timestep_scale_multiplier)) + .first; + + for (int i = 0; i < cfg.num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." + std::to_string(i)]); + auto out = block->forward(ctx, + vx, + ax, + v_context, + a_context, + nullptr, + v_timestep_mod, + a_timestep_mod, + v_pe, + a_pe, + v_cross_pe, + a_cross_pe, + av_ca_video_scale_shift_timestep, + av_ca_audio_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + v_prompt_timestep_mod, + a_prompt_timestep_mod); + vx = out.first; + ax = out.second; + } + + auto v_shift_scale = get_output_scale_shift(ctx, params["scale_shift_table"], v_embedded_time, cfg.hidden_size); + vx = norm_out->forward(ctx, vx); + vx = Flux::modulate(ctx->ggml_ctx, vx, v_shift_scale[0], v_shift_scale[1], true); + vx = proj_out->forward(ctx, vx); + vx = unpatchify_video(ctx, vx, width, height, frames); + + if (ax != nullptr && audio_time > 0) { + auto a_shift_scale = get_output_scale_shift(ctx, params["audio_scale_shift_table"], a_embedded_time, cfg.audio_hidden_size); + ax = audio_norm_out->forward(ctx, ax); + ax = Flux::modulate(ctx->ggml_ctx, ax, a_shift_scale[0], a_shift_scale[1], true); + ax = audio_proj_out->forward(ctx, ax); + ax = unpatchify_audio(ctx, ax, audio_time); + } + + return {vx, ax}; + } + }; + + struct LTXAVRunner : public GGMLRunner { + std::string prefix; + LTXAVParams params; + LTXAVModelBlock model; + std::vector video_pe_vec; + std::vector audio_pe_vec; + std::vector video_cross_pe_vec; + std::vector audio_cross_pe_vec; + std::vector connector_pe_vec; + std::vector audio_connector_pe_vec; + sd::Tensor vx_input_cache; + sd::Tensor ax_input_cache; + + static int64_t infer_gate_heads(const String2TensorStorage& tensor_storage_map, + const std::string& bias_name, + int64_t fallback_heads) { + auto it = tensor_storage_map.find(bias_name); + if (it != tensor_storage_map.end()) { + return it->second.ne[0]; + } + return 
fallback_heads; + } + + LTXAVRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string& prefix = "model.diffusion_model") + : GGMLRunner(backend, offload_params_to_cpu), + prefix(prefix), + params(), + model(params) { + auto patchify_proj_iter = tensor_storage_map.find(prefix + ".patchify_proj.weight"); + if (patchify_proj_iter != tensor_storage_map.end()) { + params.in_channels = patchify_proj_iter->second.ne[0]; + params.hidden_size = patchify_proj_iter->second.ne[1]; + int64_t video_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.attn1.to_gate_logits.bias", 32); + auto attn_layout = infer_attention_layout(params.hidden_size, video_heads); + params.num_attention_heads = attn_layout.first; + params.attention_head_dim = attn_layout.second; + } + + auto audio_patchify_proj_iter = tensor_storage_map.find(prefix + ".audio_patchify_proj.weight"); + if (audio_patchify_proj_iter != tensor_storage_map.end()) { + params.audio_in_channels = audio_patchify_proj_iter->second.ne[0]; + params.audio_hidden_size = audio_patchify_proj_iter->second.ne[1]; + params.audio_out_channels = params.audio_in_channels; + int64_t audio_heads = infer_gate_heads(tensor_storage_map, prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.bias", 32); + auto audio_attn_layout = infer_attention_layout(params.audio_hidden_size, audio_heads); + params.audio_num_attention_heads = audio_attn_layout.first; + params.audio_attention_head_dim = audio_attn_layout.second; + } + + auto proj_out_iter = tensor_storage_map.find(prefix + ".proj_out.weight"); + if (proj_out_iter != tensor_storage_map.end()) { + params.out_channels = proj_out_iter->second.ne[1]; + } + auto audio_proj_out_iter = tensor_storage_map.find(prefix + ".audio_proj_out.weight"); + if (audio_proj_out_iter != tensor_storage_map.end()) { + params.audio_out_channels = audio_proj_out_iter->second.ne[1]; + } + + auto attn2_iter = 
tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_k.weight"); + if (attn2_iter != tensor_storage_map.end()) { + params.cross_attention_dim = attn2_iter->second.ne[0]; + } + auto audio_attn2_iter = tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_k.weight"); + if (audio_attn2_iter != tensor_storage_map.end()) { + params.audio_cross_attention_dim = audio_attn2_iter->second.ne[0]; + } + if (tensor_storage_map.find(prefix + ".transformer_blocks.0.prompt_scale_shift_table") != tensor_storage_map.end()) { + params.cross_attention_adaln = true; + } + if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end() || + tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn1.to_gate_logits.weight") != tensor_storage_map.end()) { + params.self_attention_gated = true; + } + if (tensor_storage_map.find(prefix + ".transformer_blocks.0.attn2.to_gate_logits.weight") != tensor_storage_map.end() || + tensor_storage_map.find(prefix + ".transformer_blocks.0.audio_attn2.to_gate_logits.weight") != tensor_storage_map.end()) { + params.cross_attention_gated = true; + } + if (tensor_storage_map.find(prefix + ".caption_projection.linear_1.weight") == tensor_storage_map.end() && + tensor_storage_map.find(prefix + ".caption_projection.linear_2.weight") == tensor_storage_map.end()) { + params.use_caption_projection = false; + } + if (tensor_storage_map.find(prefix + ".audio_caption_projection.linear_1.weight") == tensor_storage_map.end() && + tensor_storage_map.find(prefix + ".audio_caption_projection.linear_2.weight") == tensor_storage_map.end()) { + params.use_audio_caption_projection = false; + } + + params.num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".", "transformer_blocks."); + + auto connector_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); + if (connector_iter != tensor_storage_map.end()) { + 
params.use_connector = true; + params.connector_hidden_size = connector_iter->second.ne[1]; + int64_t connector_heads = infer_gate_heads(tensor_storage_map, + prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", + 32); + auto connector_layout = infer_attention_layout(params.connector_hidden_size, connector_heads); + params.connector_num_heads = connector_layout.first; + params.connector_head_dim = connector_layout.second; + params.connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".video_embeddings_connector.", "transformer_1d_blocks."); + auto register_iter = tensor_storage_map.find(prefix + ".video_embeddings_connector.learnable_registers"); + if (register_iter != tensor_storage_map.end()) { + params.connector_num_registers = register_iter->second.ne[1]; + } + if (tensor_storage_map.find(prefix + ".video_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { + params.connector_apply_gated_attention = true; + } + } + + auto audio_connector_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.weight"); + if (audio_connector_iter != tensor_storage_map.end()) { + params.use_audio_connector = true; + params.audio_connector_hidden_size = audio_connector_iter->second.ne[1]; + int64_t connector_heads = infer_gate_heads(tensor_storage_map, + prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.bias", + 32); + auto connector_layout = infer_attention_layout(params.audio_connector_hidden_size, connector_heads); + params.audio_connector_num_heads = connector_layout.first; + params.audio_connector_head_dim = connector_layout.second; + params.audio_connector_num_layers = count_prefix_blocks(tensor_storage_map, prefix + ".audio_embeddings_connector.", "transformer_1d_blocks."); + auto register_iter = tensor_storage_map.find(prefix + ".audio_embeddings_connector.learnable_registers"); + if 
(register_iter != tensor_storage_map.end()) { + params.audio_connector_num_registers = register_iter->second.ne[1]; + } + if (tensor_storage_map.find(prefix + ".audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_gate_logits.weight") != tensor_storage_map.end()) { + params.audio_connector_apply_gated_attention = true; + } + } + + model = LTXAVModelBlock(params); + model.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "ltxav"; + } + + void get_param_tensors(std::map& tensors, const std::string& prefix) { + model.get_param_tensors(tensors, prefix); + } + + std::pair, sd::Tensor> split_av_latents(const sd::Tensor& x_tensor, + int audio_length) const { + if (x_tensor.empty()) { + return {{}, {}}; + } + + GGML_ASSERT(x_tensor.dim() == 4 || x_tensor.dim() == 5); + if (x_tensor.dim() == 5) { + GGML_ASSERT(x_tensor.shape()[4] == 1); + } + int64_t width = x_tensor.shape()[0]; + int64_t height = x_tensor.shape()[1]; + int64_t frames = x_tensor.shape()[2]; + int64_t total_channels = x_tensor.shape()[3]; + int64_t spatial_size = width * height * frames; + + GGML_ASSERT(total_channels >= params.in_channels); + + sd::Tensor vx({width, height, frames, params.in_channels}); + size_t video_values = static_cast(params.in_channels * spatial_size); + std::copy_n(x_tensor.data(), video_values, vx.data()); + + if (audio_length <= 0 || total_channels == params.in_channels) { + return {vx, {}}; + } + + int64_t needed_audio_values = static_cast(audio_length) * params.num_audio_channels * params.audio_frequency_bins; + int64_t packed_audio_values = (total_channels - params.in_channels) * spatial_size; + GGML_ASSERT(packed_audio_values >= needed_audio_values); + + sd::Tensor ax({params.audio_frequency_bins, audio_length, params.num_audio_channels, 1}); + const float* audio_src = x_tensor.data() + video_values; + std::copy_n(audio_src, static_cast(needed_audio_values), ax.data()); + return {vx, ax}; + } + + ggml_tensor* 
merge_av_latents(ggml_context* ctx, + ggml_tensor* vx, + ggml_tensor* ax) const { + if (ax == nullptr || ggml_nelements(ax) == 0 || ax->ne[1] == 0) { + return vx; + } + + int64_t width = vx->ne[0]; + int64_t height = vx->ne[1]; + int64_t frames = vx->ne[2]; + int64_t divisor = width * height * frames; + int64_t audio_values = ax->ne[0] * ax->ne[1] * ax->ne[2] * ax->ne[3]; + int64_t pad_values = (divisor - (audio_values % divisor)) % divisor; + int64_t padded_len = audio_values + pad_values; + + ax = ggml_cont(ctx, ax); + ax = ggml_reshape_4d(ctx, ax, audio_values, 1, 1, 1); + if (pad_values > 0) { + ax = ggml_ext_pad(ctx, ax, static_cast(pad_values), 0, 0, 0); + } + int64_t extra_channels = padded_len / divisor; + ax = ggml_reshape_4d(ctx, ax, width, height, frames, extra_channels); + return ggml_concat(ctx, vx, ax, 3); + } + + ggml_cgraph* build_graph(const sd::Tensor& x_tensor, + const sd::Tensor& timesteps_tensor, + const sd::Tensor& context_tensor = {}, + const sd::Tensor& audio_x_tensor = {}, + const sd::Tensor& audio_timesteps_tensor = {}, + int audio_length = 0) { + auto split_inputs = split_av_latents(x_tensor, audio_length); + vx_input_cache = split_inputs.first; + if (!audio_x_tensor.empty()) { + ax_input_cache = audio_x_tensor; + } else { + ax_input_cache = split_inputs.second; + } + + ggml_tensor* vx = make_input(vx_input_cache); + ggml_tensor* ax = make_optional_input(ax_input_cache); + ggml_tensor* timesteps = make_input(timesteps_tensor); + ggml_tensor* a_timestep = make_optional_input(audio_timesteps_tensor); + ggml_tensor* context = make_optional_input(context_tensor); + + ggml_cgraph* gf = new_graph_custom(LTXAV_GRAPH_SIZE); + + video_pe_vec = build_video_rope_matrix(vx->ne[0], + vx->ne[1], + vx->ne[2], + static_cast(params.hidden_size), + static_cast(params.num_attention_heads), + 24.f, + params.positional_embedding_theta, + params.positional_embedding_max_pos, + params.vae_scale_factors, + params.causal_temporal_positioning, + 
params.use_middle_indices_grid); + auto video_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.attention_head_dim / 2, vx->ne[0] * vx->ne[1] * vx->ne[2] * params.num_attention_heads); + ggml_set_name(video_pe, "ltxav_video_pe"); + set_backend_tensor_data(video_pe, video_pe_vec.data()); + + ggml_tensor* audio_pe = nullptr; + ggml_tensor* video_cross_pe = nullptr; + ggml_tensor* audio_cross_pe = nullptr; + if (ax != nullptr && ggml_nelements(ax) > 0 && ax->ne[1] > 0) { + audio_pe_vec = build_audio_rope_matrix(ax->ne[1], + static_cast(params.audio_hidden_size), + static_cast(params.audio_num_attention_heads), + params.positional_embedding_theta, + params.audio_positional_embedding_max_pos[0], + params.use_middle_indices_grid); + audio_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, ax->ne[1] * params.audio_num_attention_heads); + ggml_set_name(audio_pe, "ltxav_audio_pe"); + set_backend_tensor_data(audio_pe, audio_pe_vec.data()); + + int temporal_max_pos = std::max(params.positional_embedding_max_pos[0], params.audio_positional_embedding_max_pos[0]); + video_cross_pe_vec = build_video_temporal_rope_matrix(vx->ne[0], + vx->ne[1], + vx->ne[2], + static_cast(params.audio_cross_attention_dim), + static_cast(params.audio_num_attention_heads), + 25.f, + params.positional_embedding_theta, + temporal_max_pos, + std::get<0>(params.vae_scale_factors), + params.causal_temporal_positioning, + true); + video_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, vx->ne[0] * vx->ne[1] * vx->ne[2] * params.audio_num_attention_heads); + ggml_set_name(video_cross_pe, "ltxav_video_cross_pe"); + set_backend_tensor_data(video_cross_pe, video_cross_pe_vec.data()); + + audio_cross_pe_vec = build_audio_rope_matrix(ax->ne[1], + static_cast(params.audio_cross_attention_dim), + static_cast(params.audio_num_attention_heads), + params.positional_embedding_theta, + temporal_max_pos, + 
true); + audio_cross_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_attention_head_dim / 2, ax->ne[1] * params.audio_num_attention_heads); + ggml_set_name(audio_cross_pe, "ltxav_audio_cross_pe"); + set_backend_tensor_data(audio_cross_pe, audio_cross_pe_vec.data()); + } + + bool needs_video_connector_pe = + params.use_connector && + context != nullptr && + (context->ne[0] == params.connector_hidden_size || + ((context->ne[0] == params.cross_attention_dim + params.audio_cross_attention_dim || + context->ne[0] == params.caption_channels * 2) && + context->ne[1] < 1024)); + ggml_tensor* video_connector_pe = nullptr; + if (needs_video_connector_pe) { + int64_t seq_len = context->ne[1]; + int64_t target_len = std::max(1024, seq_len); + int64_t duplications = (target_len + params.connector_num_registers - 1) / params.connector_num_registers; + int64_t full_len = seq_len + duplications * params.connector_num_registers - seq_len; + connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(params.connector_hidden_size), static_cast(params.connector_num_heads), 10000.f, 4096.f, true); + video_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.connector_head_dim / 2, full_len * params.connector_num_heads); + ggml_set_name(video_connector_pe, "ltxav_video_connector_pe"); + set_backend_tensor_data(video_connector_pe, connector_pe_vec.data()); + } + + bool run_audio_context = + ax != nullptr && + ggml_nelements(ax) > 0 && + ax->ne[1] > 0; + bool needs_audio_connector_pe = + run_audio_context && + params.use_audio_connector && + context != nullptr && + (context->ne[0] == params.audio_connector_hidden_size || + ((context->ne[0] == params.cross_attention_dim + params.audio_cross_attention_dim || + context->ne[0] == params.caption_channels * 2) && + context->ne[1] < 1024)); + ggml_tensor* audio_connector_pe = nullptr; + if (needs_audio_connector_pe) { + int64_t seq_len = context->ne[1]; + int64_t target_len = std::max(1024, 
seq_len); + int64_t duplications = (target_len + params.audio_connector_num_registers - 1) / params.audio_connector_num_registers; + int64_t full_len = seq_len + duplications * params.audio_connector_num_registers - seq_len; + audio_connector_pe_vec = build_1d_rope_matrix(full_len, static_cast(params.audio_connector_hidden_size), static_cast(params.audio_connector_num_heads), 10000.f, 4096.f, true); + audio_connector_pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, params.audio_connector_head_dim / 2, full_len * params.audio_connector_num_heads); + ggml_set_name(audio_connector_pe, "ltxav_audio_connector_pe"); + set_backend_tensor_data(audio_connector_pe, audio_connector_pe_vec.data()); + } + + auto runner_ctx = get_context(); + auto out_pair = model.forward(&runner_ctx, + vx, + ax, + timesteps, + a_timestep, + context, + video_pe, + audio_pe, + video_cross_pe, + audio_cross_pe, + video_connector_pe, + audio_connector_pe); + auto out = merge_av_latents(compute_ctx, out_pair.first, out_pair.second); + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context = {}, + const sd::Tensor& audio_x = {}, + const sd::Tensor& audio_timesteps = {}, + int audio_length = 0) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x, timesteps, context, audio_x, audio_timesteps, audio_length); + }; + auto out = restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false), x.dim()); + return out; + } + + void test(const std::string& x_path, + const std::string& timesteps_path = "", + const std::string& context_path = "", + const std::string& audio_x_path = "", + const std::string& audio_timesteps_path = "") { + auto x = sd::load_tensor_from_file_as_tensor(x_path); + GGML_ASSERT(!x.empty()); + print_sd_tensor(x, false, "ltxav_x"); + + sd::Tensor timesteps; + if (!timesteps_path.empty()) { + timesteps = 
sd::load_tensor_from_file_as_tensor(timesteps_path); + } else { + timesteps = sd::Tensor::from_vector(std::vector{1.f}); + } + GGML_ASSERT(!timesteps.empty()); + print_sd_tensor(timesteps, false, "ltxav_timesteps"); + + sd::Tensor context; + if (!context_path.empty()) { + context = sd::load_tensor_from_file_as_tensor(context_path); + GGML_ASSERT(!context.empty()); + print_sd_tensor(context, false, "ltxav_context"); + } + + sd::Tensor audio_x; + int audio_length = 0; + if (!audio_x_path.empty()) { + audio_x = sd::load_tensor_from_file_as_tensor(audio_x_path); + GGML_ASSERT(!audio_x.empty()); + GGML_ASSERT(audio_x.dim() >= 2); + audio_length = static_cast(audio_x.shape()[1]); + print_sd_tensor(audio_x, false, "ltxav_audio_x"); + } + + sd::Tensor audio_timesteps; + if (!audio_timesteps_path.empty()) { + audio_timesteps = sd::load_tensor_from_file_as_tensor(audio_timesteps_path); + GGML_ASSERT(!audio_timesteps.empty()); + } else if (!audio_x.empty()) { + audio_timesteps = timesteps; + } + if (!audio_timesteps.empty()) { + print_sd_tensor(audio_timesteps, false, "ltxav_audio_timesteps"); + } + + int64_t t0 = ggml_time_ms(); + auto out_opt = compute(8, x, timesteps, context, audio_x, audio_timesteps, audio_length); + int64_t t1 = ggml_time_ms(); + + GGML_ASSERT(!out_opt.empty()); + print_sd_tensor(out_opt, false, "ltxav_out"); + LOG_DEBUG("ltxav test done in %lldms", t1 - t0); + } + + static void load_from_file_and_test(const std::string& model_path, + const std::string& x_path, + const std::string& timesteps_path = "", + const std::string& context_path = "", + const std::string& embeddings_path = "", + const std::string& audio_x_path = "", + const std::string& audio_timesteps_path = "") { + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + LOG_INFO("loading ltxav from '%s'", model_path.c_str()); + + ModelLoader model_loader; + if (!model_loader.init_from_file_and_convert_name(model_path, 
"model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); + return; + } + if (!embeddings_path.empty()) { + LOG_INFO("loading ltxav embeddings from '%s'", embeddings_path.c_str()); + if (!model_loader.init_from_file(embeddings_path)) { + LOG_ERROR("init embeddings model loader from file failed: '%s'", embeddings_path.c_str()); + return; + } + } + + auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + std::shared_ptr ltxav = std::make_shared(backend, + false, + tensor_storage_map, + "model.diffusion_model"); + + ltxav->alloc_params_buffer(); + std::map tensors; + ltxav->get_param_tensors(tensors, "model.diffusion_model"); + + if (!model_loader.load_tensors(tensors)) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("ltxav model loaded"); + ltxav->test(x_path, timesteps_path, context_path, audio_x_path, audio_timesteps_path); + } + }; + +}; // namespace LTXV + +#endif diff --git a/src/model.cpp b/src/model.cpp index 3479a0be..3ed38a57 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -471,6 +471,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.layers.0.adaLN_sa_ln.weight") != std::string::npos) { return VERSION_ERNIE_IMAGE; } + if (tensor_storage.name.find("model.diffusion_model.adaln_single.emb.timestep_embedder.linear_1.bias") != std::string::npos) { + return VERSION_LTXAV; + } if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { is_wan = true; } diff --git a/src/model.h b/src/model.h index 65bc6c36..27c873df 100644 --- a/src/model.h +++ b/src/model.h @@ -42,6 +42,7 @@ enum SDVersion { VERSION_ANIMA, VERSION_FLUX2, VERSION_FLUX2_KLEIN, + VERSION_LTXAV, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, VERSION_ERNIE_IMAGE, @@ -104,6 +105,13 @@ static inline bool sd_version_is_flux2(SDVersion version) { return false; } +static inline bool sd_version_is_ltxav(SDVersion version) { 
+ if (version == VERSION_LTXAV) { + return true; + } + return false; +} + static inline bool sd_version_is_wan(SDVersion version) { if (version == VERSION_WAN2 || version == VERSION_WAN2_2_I2V || version == VERSION_WAN2_2_TI2V) { return true; @@ -160,6 +168,7 @@ static inline bool sd_version_is_inpaint(SDVersion version) { static inline bool sd_version_is_dit(SDVersion version) { if (sd_version_is_flux(version) || sd_version_is_flux2(version) || + sd_version_is_ltxav(version) || sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c6541148..bab6fee9 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -14,6 +14,7 @@ #include "diffusion_model.hpp" #include "esrgan.hpp" #include "lora.hpp" +#include "ltx_vae.hpp" #include "pmid.hpp" #include "sample-cache.h" #include "tae.hpp" @@ -52,6 +53,7 @@ const char* model_version_to_str[] = { "Anima", "Flux.2", "Flux.2 klein", + "LTXAV", "Z-Image", "Ovis Image", "Ernie Image", @@ -351,6 +353,17 @@ public: return false; } + if (strlen(SAFE_STR(sd_ctx_params->embeddings_connectors_path)) > 0) { + if (sd_version_is_ltxav(version)) { + LOG_INFO("loading embeddings connectors from '%s'", sd_ctx_params->embeddings_connectors_path); + if (!model_loader.init_from_file(sd_ctx_params->embeddings_connectors_path)) { + LOG_WARN("loading embeddings connectors from '%s' failed", sd_ctx_params->embeddings_connectors_path); + } + } else { + LOG_WARN("ignoring embeddings connectors for non-LTXAV model: '%s'", sd_ctx_params->embeddings_connectors_path); + } + } + auto& tensor_storage_map = model_loader.get_tensor_storage_map(); LOG_INFO("Version: %s ", model_version_to_str[version]); @@ -415,6 +428,9 @@ public: // Might need vae encode for control cond vae_decode_only = false; } + if (sd_version_is_ltxav(version)) { + vae_decode_only = true; + } bool tae_preview_only = sd_ctx_params->tae_preview_only; if (version 
== VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { @@ -492,6 +508,14 @@ public: tensor_storage_map, version, sd_ctx_params->chroma_use_dit_mask); + } else if (sd_version_is_ltxav(version)) { + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model"); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -638,9 +662,16 @@ public: }; auto create_vae = [&]() -> std::shared_ptr { - if (sd_version_is_wan(version) || - sd_version_is_qwen_image(version) || - sd_version_is_anima(version)) { + if (sd_version_is_ltxav(version)) { + return std::make_shared(vae_backend, + offload_params_to_cpu, + tensor_storage_map, + "first_stage_model", + true, + version); + } else if (sd_version_is_wan(version) || + sd_version_is_qwen_image(version) || + sd_version_is_anima(version)) { return std::make_shared(vae_backend, offload_params_to_cpu, tensor_storage_map, @@ -936,13 +967,16 @@ public: pred_type = EPS_PRED; } } else if (sd_version_is_sd3(version) || + sd_version_is_ltxav(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_ernie_image(version) || sd_version_is_z_image(version)) { pred_type = FLOW_PRED; - if (sd_version_is_wan(version)) { + if (sd_version_is_ltxav(version)) { + default_flow_shift = 2.37f; + } else if (sd_version_is_wan(version)) { default_flow_shift = 5.f; } else if (sd_version_is_ernie_image(version)) { default_flow_shift = 4.f; @@ -979,8 +1013,13 @@ public: denoiser = std::make_shared(); break; case FLOW_PRED: { - LOG_INFO("running in FLOW mode"); - denoiser = std::make_shared(); + if (sd_version_is_ltxav(version)) { + LOG_INFO("running in LTXAV FLOW mode"); + denoiser = std::make_shared(); + } else { + LOG_INFO("running in FLOW mode"); + denoiser = std::make_shared(); 
+ } break; } case FLUX_FLOW_PRED: { @@ -1621,6 +1660,7 @@ public: const sd::Tensor& denoise_mask, const sd::Tensor& vace_context, float vace_strength, + int audio_length, const sd_cache_params_t* cache_params) { std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); float cfg_scale = guidance.txt_cfg; @@ -1699,6 +1739,7 @@ public: diffusion_params.control_strength = control_strength; diffusion_params.vace_context = vace_context.empty() ? nullptr : &vace_context; diffusion_params.vace_strength = vace_strength; + diffusion_params.audio_length = audio_length; diffusion_params.skip_layers = nullptr; compute_sample_controls(control_image, @@ -1860,7 +1901,9 @@ public: int get_latent_channel() { int latent_channel = 4; if (sd_version_is_dit(version)) { - if (version == VERSION_WAN2_2_TI2V) { + if (sd_version_is_ltxav(version)) { + latent_channel = 128; + } else if (version == VERSION_WAN2_2_TI2V) { latent_channel = 48; } else if (version == VERSION_CHROMA_RADIANCE) { latent_channel = 3; @@ -1886,7 +1929,9 @@ public: int W = width / vae_scale_factor; int H = height / vae_scale_factor; int T = frames; - if (sd_version_is_wan(version)) { + if (sd_version_is_ltxav(version)) { + T = ((T - 1) / 8) + 1; + } else if (sd_version_is_wan(version)) { T = ((T - 1) / 4) + 1; } int C = get_latent_channel(); @@ -2223,6 +2268,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "llm_vision_path: %s\n" "diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n" + "embeddings_connectors_path: %s\n" "vae_path: %s\n" "taesd_path: %s\n" "control_net_path: %s\n" @@ -2255,6 +2301,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->llm_vision_path), SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), + SAFE_STR(sd_ctx_params->embeddings_connectors_path), SAFE_STR(sd_ctx_params->vae_path), SAFE_STR(sd_ctx_params->taesd_path), 
SAFE_STR(sd_ctx_params->control_net_path), @@ -2433,6 +2480,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->strength = 0.75f; sd_vid_gen_params->seed = -1; sd_vid_gen_params->video_frames = 6; + sd_vid_gen_params->fps = 16; sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->vace_strength = 1.f; sd_vid_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; @@ -2444,7 +2492,7 @@ struct sd_ctx_t { }; static bool sd_version_supports_video_generation(SDVersion version) { - return version == VERSION_SVD || sd_version_is_wan(version); + return version == VERSION_SVD || sd_version_is_wan(version) || sd_version_is_ltxav(version); } static bool sd_version_supports_image_generation(SDVersion version) { @@ -2589,6 +2637,8 @@ struct GenerationRequest { sd_pm_params_t pm_params = {}; sd_hires_params_t hires = {}; int frames = -1; + int requested_frames = -1; + int fps = 16; float vace_strength = 1.f; GenerationRequest(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { @@ -2615,12 +2665,18 @@ struct GenerationRequest { } GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { - prompt = SAFE_STR(sd_vid_gen_params->prompt); - negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); - width = sd_vid_gen_params->width; - height = sd_vid_gen_params->height; - frames = (sd_vid_gen_params->video_frames - 1) / 4 * 4 + 1; + prompt = SAFE_STR(sd_vid_gen_params->prompt); + negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); + width = sd_vid_gen_params->width; + height = sd_vid_gen_params->height; + requested_frames = std::max(1, sd_vid_gen_params->video_frames); + if (sd_version_is_ltxav(sd_ctx->sd->version)) { + frames = ((requested_frames - 1 + 7) / 8) * 8 + 1; + } else { + frames = (requested_frames - 1) / 4 * 4 + 1; + } clip_skip = sd_vid_gen_params->clip_skip; + fps = std::max(1, sd_vid_gen_params->fps); vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); 
diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor(); seed = sd_vid_gen_params->seed; @@ -2629,6 +2685,12 @@ struct GenerationRequest { guidance = sd_vid_gen_params->sample_params.guidance; high_noise_guidance = sd_vid_gen_params->high_noise_sample_params.guidance; resolve(sd_ctx); + if (frames != requested_frames) { + LOG_WARN("align video frames from %d to %d for %s", + requested_frames, + frames, + model_version_to_str[sd_ctx->sd->version]); + } } void align_generation_request_size() { @@ -2858,6 +2920,7 @@ struct ImageGenerationLatents { sd::Tensor init_latent; sd::Tensor concat_latent; sd::Tensor uncond_concat_latent; + sd::Tensor audio_latent; sd::Tensor control_image; std::vector> ref_images; std::vector> ref_latents; @@ -2865,8 +2928,51 @@ struct ImageGenerationLatents { sd::Tensor clip_vision_output; sd::Tensor vace_context; int64_t ref_image_num = 0; + int audio_length = 0; }; +static sd::Tensor pack_ltxav_audio_and_video_latents(const sd::Tensor& video_latent, + const sd::Tensor& audio_latent) { + if (audio_latent.empty()) { + return video_latent; + } + + GGML_ASSERT(video_latent.dim() == 4 || video_latent.dim() == 5); + GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4); + if (video_latent.dim() == 5) { + GGML_ASSERT(video_latent.shape()[4] == 1); + } + if (audio_latent.dim() == 4) { + GGML_ASSERT(audio_latent.shape()[3] == 1); + } + + int64_t width = video_latent.shape()[0]; + int64_t height = video_latent.shape()[1]; + int64_t frames = video_latent.shape()[2]; + int64_t video_ch = video_latent.shape()[3]; + int64_t spatial_size = width * height * frames; + int64_t audio_values = audio_latent.numel(); + int64_t extra_ch = (audio_values + spatial_size - 1) / spatial_size; + + std::vector packed_shape = video_latent.shape(); + packed_shape[3] = video_ch + extra_ch; + sd::Tensor packed = sd::zeros(packed_shape); + + std::copy_n(video_latent.data(), video_latent.numel(), packed.data()); + std::copy_n(audio_latent.data(), 
audio_latent.numel(), packed.data() + video_latent.numel()); + return packed; +} + +static int get_ltxav_num_audio_latents(int frames, int fps) { + GGML_ASSERT(frames > 0); + GGML_ASSERT(fps > 0); + constexpr float kSampleRate = 16000.0f; + constexpr float kMelHopLength = 160.0f; + constexpr float kAudioLatentDownsample = 4.0f; + constexpr float kLatentsPerSecond = kSampleRate / kMelHopLength / kAudioLatentDownsample; + return static_cast(std::ceil((static_cast(frames) / static_cast(fps)) * kLatentsPerSecond)); +} + struct ImageGenerationEmbeds { SDCondition cond; SDCondition uncond; @@ -3454,6 +3560,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s latents.denoise_mask, sd::Tensor(), 1.f, + 0, request.cache_params); int64_t sampling_end = ggml_time_ms(); if (!x_0.empty()) { @@ -3575,6 +3682,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s hires_denoise_mask, sd::Tensor(), 1.f, + 0, request.cache_params); int64_t hires_sample_end = ggml_time_ms(); if (!x_0.empty()) { @@ -3633,6 +3741,18 @@ static std::optional prepare_video_generation_latents(sd end_image = sd_image_to_tensor(sd_vid_gen_params->end_image, request->width, request->height); } + if (sd_version_is_ltxav(sd_ctx->sd->version)) { + latents.audio_length = 0; + latents.audio_latent = {}; + } + + if (sd_version_is_ltxav(sd_ctx->sd->version)) { + if (!start_image.empty() || !end_image.empty() || sd_vid_gen_params->control_frames_size > 0) { + LOG_ERROR("LTXAV currently supports txt2vid only; init_image, end_image, and control_frames are not implemented"); + return std::nullopt; + } + } + if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" || sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" || sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-1.3B" || @@ -3803,6 +3923,9 @@ static std::optional prepare_video_generation_latents(sd latents.init_latent = sd_ctx->sd->generate_init_latent(request->width, 
request->height, request->frames, true); } + // Pipeline-level audio support is temporarily disabled. Keep the model-side + // AV implementation intact, but feed pure video latents through vid_gen. + return latents; } @@ -3839,14 +3962,26 @@ static ImageGenerationEmbeds prepare_video_generation_embeds(sd_ctx_t* sd_ctx, } static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, + const GenerationRequest& request, const sd::Tensor& final_latent, int* num_frames_out) { if (final_latent.empty()) { LOG_ERROR("no latent video to decode"); return nullptr; } + sd::Tensor video_latent = final_latent; + if (sd_version_is_ltxav(sd_ctx->sd->version) && + video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) { + video_latent = sd::ops::slice(video_latent, 3, 0, sd_ctx->sd->get_latent_channel()); + } + LOG_DEBUG("decode_video_outputs latent %dx%dx%dx%d", + (int)video_latent.shape()[0], + (int)video_latent.shape()[1], + (int)video_latent.shape()[2], + (int)video_latent.shape()[3]); + // auto z = sd::load_tensor_from_file_as_tensor("ltx_vae_z.bin"); int64_t t4 = ggml_time_ms(); - sd::Tensor vid = sd_ctx->sd->decode_first_stage(final_latent, true); + sd::Tensor vid = sd_ctx->sd->decode_first_stage(video_latent, true); int64_t t5 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t5 - t4) * 1.0f / 1000); if (sd_ctx->sd->free_params_immediately) { @@ -3856,6 +3991,15 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, LOG_ERROR("decode_first_stage failed for video"); return nullptr; } + LOG_DEBUG("decode_video_outputs decoded %dx%dx%dx%d", + (int)vid.shape()[0], + (int)vid.shape()[1], + (int)vid.shape()[2], + (int)vid.shape()[3]); + if (request.requested_frames > 0 && + vid.shape()[2] > request.requested_frames) { + vid = sd::ops::slice(vid, 2, 0, request.requested_frames); + } sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t)); if (result_images == nullptr) { @@ -3939,6 +4083,7 @@ SD_API sd_image_t* 
generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s latents.denoise_mask, latents.vace_context, request.vace_strength, + latents.audio_length, request.cache_params); int64_t sampling_end = ggml_time_ms(); if (x_t_sampled.empty()) { @@ -3981,6 +4126,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s latents.denoise_mask, latents.vace_context, request.vace_strength, + latents.audio_length, request.cache_params); int64_t sampling_end = ggml_time_ms(); @@ -4000,7 +4146,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int64_t latent_end = ggml_time_ms(); LOG_INFO("generating latent video completed, taking %.2fs", (latent_end - latent_start) * 1.0f / 1000); - auto result = decode_video_outputs(sd_ctx, final_latent, num_frames_out); + auto result = decode_video_outputs(sd_ctx, request, final_latent, num_frames_out); if (result == nullptr) { return nullptr; } diff --git a/src/tae.hpp b/src/tae.hpp index 0a0ca682..41b53515 100644 --- a/src/tae.hpp +++ b/src/tae.hpp @@ -2,7 +2,6 @@ #define __TAE_HPP__ #include "ggml_extend.hpp" - #include "model.h" /* diff --git a/src/tensor_ggml.hpp b/src/tensor_ggml.hpp index 493a958c..c6e9d4ac 100644 --- a/src/tensor_ggml.hpp +++ b/src/tensor_ggml.hpp @@ -104,7 +104,7 @@ namespace sd { throw std::invalid_argument("tensor file type does not match requested sd::Tensor type"); } - std::vector shape(4, 1); + std::vector shape(n_dims, 1); for (int i = 0; i < n_dims; ++i) { int32_t dim = 1; file.read(reinterpret_cast(&dim), sizeof(dim)); diff --git a/src/tokenizers/gemma_tokenizer.cpp b/src/tokenizers/gemma_tokenizer.cpp index 76880150..dd9026eb 100644 --- a/src/tokenizers/gemma_tokenizer.cpp +++ b/src/tokenizers/gemma_tokenizer.cpp @@ -50,6 +50,7 @@ GemmaTokenizer::GemmaTokenizer(const std::string& merges_utf8_str, const std::st byte_level_bpe = false; byte_fallback = true; add_bos_token = true; + pad_left = true; PAD_TOKEN = ""; EOS_TOKEN = ""; BOS_TOKEN = ""; 
diff --git a/src/vae.hpp b/src/vae.hpp index dc69535e..bf4966e5 100644 --- a/src/vae.hpp +++ b/src/vae.hpp @@ -67,7 +67,9 @@ public: int get_scale_factor() { int scale_factor = 8; - if (version == VERSION_WAN2_2_TI2V) { + if (version == VERSION_LTXAV) { + scale_factor = 32; + } else if (version == VERSION_WAN2_2_TI2V) { scale_factor = 16; } else if (sd_version_uses_flux2_vae(version)) { scale_factor = 16; diff --git a/src/wan.hpp b/src/wan.hpp index 6860262c..cdefec26 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -966,10 +966,10 @@ namespace WAN { blocks["conv2"] = std::shared_ptr(new CausalConv3d(z_dim, z_dim, {1, 1, 1})); } - ggml_tensor* patchify(ggml_context* ctx, - ggml_tensor* x, - int64_t patch_size, - int64_t b = 1) { + static ggml_tensor* patchify(ggml_context* ctx, + ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { // x: [b*c, f, h*q, w*r] // return: [b*c*r*q, f, h, w] if (patch_size == 1) { @@ -993,10 +993,10 @@ namespace WAN { return x; } - ggml_tensor* unpatchify(ggml_context* ctx, - ggml_tensor* x, - int64_t patch_size, - int64_t b = 1) { + static ggml_tensor* unpatchify(ggml_context* ctx, + ggml_tensor* x, + int64_t patch_size, + int64_t b = 1) { // x: [b*c*r*q, f, h, w] // return: [b*c, f, h*q, w*r] if (patch_size == 1) {